forked from mindspore-Ecosystem/mindspore
gnmt add hub_config
This commit is contained in:
parent
c52a4381dd
commit
2f2a624281
|
@ -135,11 +135,13 @@ The GNMT network script and code result are as follows:
|
||||||
│ ├──lr_scheduler.py // Learning rate scheduler.
|
│ ├──lr_scheduler.py // Learning rate scheduler.
|
||||||
│ ├──optimizer.py // Optimizer.
|
│ ├──optimizer.py // Optimizer.
|
||||||
├── scripts
|
├── scripts
|
||||||
│ ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend.
|
│ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend.
|
||||||
│ ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend.
|
│ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend.
|
||||||
│ ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend.
|
│ ├──run_standalone_train_ascend.sh // Shell script for standalone train on ascend.
|
||||||
├── create_dataset.py // dataset preparation.
|
├── create_dataset.py // Dataset preparation.
|
||||||
├── eval.py // Infer API entry.
|
├── eval.py // Infer API entry.
|
||||||
|
├── export.py // Export checkpoint file into air models.
|
||||||
|
├── mindspore_hub_conf.py // Hub config.
|
||||||
├── requirements.txt // Requirements of third party package.
|
├── requirements.txt // Requirements of third party package.
|
||||||
├── train.py // Train API entry.
|
├── train.py // Train API entry.
|
||||||
```
|
```
|
||||||
|
|
|
@ -65,15 +65,10 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
value = param.data
|
value = param.data
|
||||||
name = param.name
|
weights_name = param.name
|
||||||
if name not in weights:
|
if weights_name not in weights:
|
||||||
raise ValueError(f"{name} is not found in weights.")
|
raise ValueError(f"{weights_name} is not found in weights.")
|
||||||
with open("weight_after_deal.txt", "a+") as f:
|
|
||||||
weights_name = name
|
|
||||||
f.write(weights_name)
|
|
||||||
f.write("\n")
|
|
||||||
if isinstance(value, Tensor):
|
if isinstance(value, Tensor):
|
||||||
print(name, value.asnumpy().shape)
|
|
||||||
if weights_name in weights:
|
if weights_name in weights:
|
||||||
assert weights_name in weights
|
assert weights_name in weights
|
||||||
if isinstance(weights[weights_name], Parameter):
|
if isinstance(weights[weights_name], Parameter):
|
||||||
|
@ -91,7 +86,7 @@ if __name__ == '__main__':
|
||||||
else:
|
else:
|
||||||
print("weight not found in checkpoint: " + weights_name)
|
print("weight not found in checkpoint: " + weights_name)
|
||||||
param.set_data(zero_weight(value.asnumpy().shape))
|
param.set_data(zero_weight(value.asnumpy().shape))
|
||||||
f.close()
|
|
||||||
print(" | Load weights successfully.")
|
print(" | Load weights successfully.")
|
||||||
tfm_infer = GNMTInferCell(tfm_model)
|
tfm_infer = GNMTInferCell(tfm_model)
|
||||||
tfm_infer.set_train(False)
|
tfm_infer.set_train(False)
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""hub config."""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
from config import GNMTConfig
|
||||||
|
from src.gnmt_model import GNMTNetworkWithLoss, GNMT
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(config):
    """Load the GNMT configuration and pin its numeric types.

    Args:
        config: Location of the JSON configuration, handed unchanged to
            ``GNMTConfig.from_json_file`` (presumably a file path; confirm
            against ``GNMTConfig``).

    Returns:
        The object produced by ``GNMTConfig.from_json_file`` with
        ``compute_type`` set to ``mstype.float16`` and ``dtype`` set to
        ``mstype.float32``.
    """
    gnmt_config = GNMTConfig.from_json_file(config)
    # Override whatever the JSON file specified for these two fields.
    gnmt_config.compute_type = mstype.float16
    gnmt_config.dtype = mstype.float32
    return gnmt_config
|
||||||
|
|
||||||
|
|
||||||
|
def create_network(name, *args, **kwargs):
    """Create a GNMT network for MindSpore Hub.

    Args:
        name: Name of the requested network. Only ``"gnmt"`` is supported.
        *args: Extra positional arguments forwarded to the network
            constructor after the configuration object.
        **kwargs: Recognised keys:
            ``config`` (required) -- value passed to ``get_config`` to build
            the configuration object;
            ``is_training`` (optional, default ``False``) -- when truthy the
            network is wrapped with its loss.

    Returns:
        ``GNMTNetworkWithLoss`` when ``is_training`` is truthy, otherwise
        ``GNMT``.

    Raises:
        NotImplementedError: If ``name`` is not ``"gnmt"``, or if the
            ``config`` keyword argument is missing.
    """
    # Guard clauses first: reject unknown networks before touching kwargs.
    if name != "gnmt":
        raise NotImplementedError(f"{name} is not implemented in the repo")

    if "config" not in kwargs:
        # Exception type kept for backward compatibility with existing callers.
        raise NotImplementedError("Please make sure the configuration file path is correct")

    config = get_config(kwargs["config"])
    is_training = kwargs.get("is_training", False)
    if is_training:
        # Positional unpacking goes before the keyword argument; the call is
        # equivalent to the previous ``is_training=..., *args`` spelling.
        return GNMTNetworkWithLoss(config, *args, is_training=is_training)
    return GNMT(config, *args)
|
|
@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
max_decode_length=64,
|
max_decode_length=64,
|
||||||
sos_id=2,
|
sos_id=2,
|
||||||
eos_id=3,
|
eos_id=3,
|
||||||
|
is_using_while=False,
|
||||||
compute_type=mstype.float32):
|
compute_type=mstype.float32):
|
||||||
super(BeamSearchDecoder, self).__init__()
|
super(BeamSearchDecoder, self).__init__()
|
||||||
|
|
||||||
|
@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
self.cov_penalty_factor = cov_penalty_factor
|
self.cov_penalty_factor = cov_penalty_factor
|
||||||
self.max_decode_length = max_decode_length
|
self.max_decode_length = max_decode_length
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
self.is_using_while = is_using_while
|
||||||
|
|
||||||
self.add = P.TensorAdd()
|
self.add = P.TensorAdd()
|
||||||
self.expand = P.ExpandDims()
|
self.expand = P.ExpandDims()
|
||||||
|
@ -215,6 +217,11 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
self.gather_nd = P.GatherNd()
|
self.gather_nd = P.GatherNd()
|
||||||
|
|
||||||
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
|
||||||
|
if self.is_using_while:
|
||||||
|
self.start = Tensor(0, dtype=mstype.int32)
|
||||||
|
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
|
||||||
|
mstype.int32)
|
||||||
|
else:
|
||||||
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
|
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
|
||||||
|
|
||||||
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
|
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
|
||||||
|
@ -259,7 +266,7 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
self.sub = P.Sub()
|
self.sub = P.Sub()
|
||||||
|
|
||||||
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||||
state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None,
|
state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
|
||||||
state_finished=None):
|
state_finished=None):
|
||||||
"""
|
"""
|
||||||
Beam search one_step output.
|
Beam search one_step output.
|
||||||
|
@ -359,6 +366,12 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
self.hidden_size))
|
self.hidden_size))
|
||||||
|
|
||||||
# update state_seq
|
# update state_seq
|
||||||
|
if self.is_using_while:
|
||||||
|
state_seq_new = self.cast(seq, mstype.float32)
|
||||||
|
word_indices_fp32 = self.cast(word_indices, mstype.float32)
|
||||||
|
state_seq_new[:, :, idx] = word_indices_fp32
|
||||||
|
state_seq = self.cast(state_seq_new, mstype.int32)
|
||||||
|
else:
|
||||||
state_seq = self.concat((seq, self.expand(word_indices, -1)))
|
state_seq = self.concat((seq, self.expand(word_indices, -1)))
|
||||||
|
|
||||||
cur_input_ids = self.reshape(word_indices, (-1, 1))
|
cur_input_ids = self.reshape(word_indices, (-1, 1))
|
||||||
|
@ -388,11 +401,22 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
decoder_hidden_state = self.decoder_hidden_state
|
decoder_hidden_state = self.decoder_hidden_state
|
||||||
accu_attn_scores = self.accu_attn_scores
|
accu_attn_scores = self.accu_attn_scores
|
||||||
|
|
||||||
|
if not self.is_using_while:
|
||||||
for _ in range(self.max_decode_length + 1):
|
for _ in range(self.max_decode_length + 1):
|
||||||
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
|
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
|
||||||
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||||
state_seq, state_length, decoder_hidden_state, accu_attn_scores,
|
state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,
|
||||||
state_finished)
|
state_finished)
|
||||||
|
else:
|
||||||
|
idx = self.start + 1
|
||||||
|
ends = self.start + self.max_decode_length + 1
|
||||||
|
while idx < ends:
|
||||||
|
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
|
||||||
|
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
|
||||||
|
state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
|
||||||
|
state_finished)
|
||||||
|
idx = idx + 1
|
||||||
|
|
||||||
# add length penalty scores
|
# add length penalty scores
|
||||||
penalty_len = self.length_penalty(state_length)
|
penalty_len = self.length_penalty(state_length)
|
||||||
# return penalty_len
|
# return penalty_len
|
||||||
|
@ -408,6 +432,9 @@ class BeamSearchDecoder(nn.Cell):
|
||||||
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
|
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
|
||||||
# sort sequence and attention scores
|
# sort sequence and attention scores
|
||||||
predicted_ids = self.gather_nd(state_seq, gather_indices)
|
predicted_ids = self.gather_nd(state_seq, gather_indices)
|
||||||
|
if not self.is_using_while:
|
||||||
predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
|
predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
|
||||||
|
else:
|
||||||
|
predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]
|
||||||
|
|
||||||
return predicted_ids
|
return predicted_ids
|
||||||
|
|
|
@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore import context, Parameter
|
from mindspore import context, Parameter
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint
|
|
||||||
|
|
||||||
from src.dataset import load_dataset
|
from src.dataset import load_dataset
|
||||||
from .gnmt import GNMT
|
from .gnmt import GNMT
|
||||||
|
@ -37,18 +36,6 @@ context.set_context(
|
||||||
reserve_class_name_in_scope=False)
|
reserve_class_name_in_scope=False)
|
||||||
|
|
||||||
|
|
||||||
def get_weight_and_variable(model_path, params):
|
|
||||||
print("model path is {}".format(model_path))
|
|
||||||
ms_ckpt = load_checkpoint(model_path)
|
|
||||||
with open("variable.txt", "w") as f:
|
|
||||||
for msname in ms_ckpt:
|
|
||||||
f.write(msname + "\n")
|
|
||||||
with open("weights.txt", "w") as f:
|
|
||||||
for param in params:
|
|
||||||
name = param.name
|
|
||||||
f.write(name + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
class GNMTInferCell(nn.Cell):
|
class GNMTInferCell(nn.Cell):
|
||||||
"""
|
"""
|
||||||
Encapsulation class of GNMT network infer.
|
Encapsulation class of GNMT network infer.
|
||||||
|
@ -92,20 +79,13 @@ def gnmt_infer(config, dataset):
|
||||||
use_one_hot_embeddings=False)
|
use_one_hot_embeddings=False)
|
||||||
|
|
||||||
params = tfm_model.trainable_params()
|
params = tfm_model.trainable_params()
|
||||||
get_weight_and_variable(config.existed_ckpt, params)
|
|
||||||
weights = load_infer_weights(config)
|
weights = load_infer_weights(config)
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
value = param.data
|
value = param.data
|
||||||
name = param.name
|
weights_name = param.name
|
||||||
if name not in weights:
|
if weights_name not in weights:
|
||||||
raise ValueError(f"{name} is not found in weights.")
|
raise ValueError(f"{weights_name} is not found in weights.")
|
||||||
with open("weight_after_deal.txt", "a+") as f:
|
|
||||||
weights_name = name
|
|
||||||
f.write(weights_name)
|
|
||||||
f.write("\n")
|
|
||||||
if isinstance(value, Tensor):
|
if isinstance(value, Tensor):
|
||||||
print(name, value.asnumpy().shape)
|
|
||||||
if weights_name in weights:
|
if weights_name in weights:
|
||||||
assert weights_name in weights
|
assert weights_name in weights
|
||||||
if isinstance(weights[weights_name], Parameter):
|
if isinstance(weights[weights_name], Parameter):
|
||||||
|
@ -123,7 +103,7 @@ def gnmt_infer(config, dataset):
|
||||||
else:
|
else:
|
||||||
print("weight not found in checkpoint: " + weights_name)
|
print("weight not found in checkpoint: " + weights_name)
|
||||||
param.set_data(zero_weight(value.asnumpy().shape))
|
param.set_data(zero_weight(value.asnumpy().shape))
|
||||||
f.close()
|
|
||||||
print(" | Load weights successfully.")
|
print(" | Load weights successfully.")
|
||||||
tfm_infer = GNMTInferCell(tfm_model)
|
tfm_infer = GNMTInferCell(tfm_model)
|
||||||
model = Model(tfm_infer)
|
model = Model(tfm_infer)
|
||||||
|
|
|
@ -37,7 +37,6 @@ def load_infer_weights(config):
|
||||||
ms_ckpt = load_checkpoint(model_path)
|
ms_ckpt = load_checkpoint(model_path)
|
||||||
is_npz = False
|
is_npz = False
|
||||||
weights = {}
|
weights = {}
|
||||||
with open("variable_after_deal.txt", "w") as f:
|
|
||||||
for param_name in ms_ckpt:
|
for param_name in ms_ckpt:
|
||||||
infer_name = param_name.replace("gnmt.gnmt.", "")
|
infer_name = param_name.replace("gnmt.gnmt.", "")
|
||||||
if infer_name.startswith("embedding_lookup."):
|
if infer_name.startswith("embedding_lookup."):
|
||||||
|
@ -45,17 +44,12 @@ def load_infer_weights(config):
|
||||||
weights[infer_name] = ms_ckpt[param_name]
|
weights[infer_name] = ms_ckpt[param_name]
|
||||||
else:
|
else:
|
||||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||||
f.write(infer_name)
|
|
||||||
f.write("\n")
|
|
||||||
infer_name = "beam_decoder.decoder." + infer_name
|
infer_name = "beam_decoder.decoder." + infer_name
|
||||||
if is_npz:
|
if is_npz:
|
||||||
weights[infer_name] = ms_ckpt[param_name]
|
weights[infer_name] = ms_ckpt[param_name]
|
||||||
else:
|
else:
|
||||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||||
f.write(infer_name)
|
|
||||||
f.write("\n")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
elif not infer_name.startswith("gnmt_encoder"):
|
elif not infer_name.startswith("gnmt_encoder"):
|
||||||
if infer_name.startswith("gnmt_decoder."):
|
if infer_name.startswith("gnmt_decoder."):
|
||||||
infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
|
infer_name = infer_name.replace("gnmt_decoder.", "decoder.")
|
||||||
|
@ -65,8 +59,4 @@ def load_infer_weights(config):
|
||||||
weights[infer_name] = ms_ckpt[param_name]
|
weights[infer_name] = ms_ckpt[param_name]
|
||||||
else:
|
else:
|
||||||
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
weights[infer_name] = ms_ckpt[param_name].data.asnumpy()
|
||||||
f.write(infer_name)
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
f.close()
|
|
||||||
return weights
|
return weights
|
||||||
|
|
Loading…
Reference in New Issue