forked from mindspore-Ecosystem/mindspore
commit ba3aa00e92

@@ -1,281 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer beam search module."""

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor

INF = 1. * 1e9


class LengthPenalty(nn.Cell):
    """
    Normalize scores of translations according to their length.

    Args:
        weight (float): Weight of length penalty. Default: 1.0.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 weight=1.0,
                 compute_type=mstype.float32):
        super(LengthPenalty, self).__init__()
        self.weight = weight
        self.add = P.Add()
        self.pow = P.Pow()
        self.div = P.RealDiv()
        self.cast = P.Cast()
        self.five = Tensor(5.0, mstype.float32)
        self.six = Tensor(6.0, mstype.float32)

    def construct(self, length_tensor):
        length_tensor = self.cast(length_tensor, mstype.float32)
        output = self.add(length_tensor, self.five)
        output = self.div(output, self.six)
        output = self.pow(output, self.weight)
        return output
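
# Added illustration (not part of the original file): LengthPenalty computes
# the GNMT-style penalty ((5 + length) / 6) ** weight. A minimal sketch,
# assuming PyNative mode so the Cell can be called directly:
#
#   penalty = LengthPenalty(weight=1.0)
#   lengths = Tensor(np.array([[7, 12]], np.int32))
#   penalty(lengths)  # [[2.0, 2.8333]], since (5 + 7) / 6 = 2.0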


class TileBeam(nn.Cell):
    """
    Tile the input tensor beam_width times along a new beam dimension.

    Args:
        beam_width (int): Number of beams to tile to.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 beam_width,
                 compute_type=mstype.float32):
        super(TileBeam, self).__init__()
        self.beam_width = beam_width
        self.expand = P.ExpandDims()
        self.tile = P.Tile()
        self.reshape = P.Reshape()
        self.shape = P.Shape()

    def construct(self, input_tensor):
        """
        input_tensor: shape [batch, dim1, dim2]
        output_tensor: shape [batch*beam, dim1, dim2]
        """
        shape = self.shape(input_tensor)
        input_tensor = self.expand(input_tensor, 1)
        tile_shape = (1,) + (self.beam_width,)
        for _ in range(len(shape) - 1):
            tile_shape = tile_shape + (1,)
        output = self.tile(input_tensor, tile_shape)
        out_shape = (shape[0] * self.beam_width,) + shape[1:]
        output = self.reshape(output, out_shape)
        return output
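
# Added illustration (not part of the original file): with beam_width=2 a
# [2, 3] tensor becomes [4, 3], each row repeated before the next begins:
#
#   tile = TileBeam(beam_width=2)
#   x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], np.float32))
#   tile(x)  # rows: [1 2 3], [1 2 3], [4 5 6], [4 5 6]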


class Mod(nn.Cell):
    """
    Modulo function, computed as x - floor(x / y) * y.

    Args:
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 compute_type=mstype.float32):
        super(Mod, self).__init__()
        self.compute_type = compute_type
        self.floor_div = P.FloorDiv()
        self.sub = P.Sub()
        self.multiply = P.Mul()

    def construct(self, input_x, input_y):
        x = self.floor_div(input_x, input_y)
        x = self.multiply(x, input_y)
        x = self.sub(input_x, x)
        return x


class BeamSearchDecoder(nn.Cell):
    """
    Beam search decoder.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence.
        vocab_size (int): Size of vocabulary.
        decoder (:class:`TransformerDecoderStep`): Decoder module.
        beam_width (int): beam width setting. Default: 4.
        length_penalty_weight (float): Weight of length penalty. Default: 1.0.
        max_decode_length (int): max decode length. Default: 128.
        sos_id (int): Id of sequence start token. Default: 1.
        eos_id (int): Id of sequence end token. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 vocab_size,
                 decoder,
                 beam_width=4,
                 length_penalty_weight=1.0,
                 max_decode_length=128,
                 sos_id=1,
                 eos_id=2,
                 compute_type=mstype.float32):
        super(BeamSearchDecoder, self).__init__(auto_prefix=False)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.decoder = decoder

        self.add = P.Add()
        self.expand = P.ExpandDims()
        self.reshape = P.Reshape()
        self.shape_flat = (-1,)
        self.shape = P.Shape()

        self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
        self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)

        self.select = P.Select()
        self.flat_shape = (batch_size, beam_width * vocab_size)
        self.topk = P.TopK(sorted=True)
        self.floor_div = P.FloorDiv()
        self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
        self.real_div = P.RealDiv()
        self.mod = Mod()
        self.equal = P.Equal()
        self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)

        beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
        self.beam_ids = Tensor(beam_ids, mstype.int32)
        batch_ids = np.arange(batch_size * beam_width).reshape((batch_size, beam_width)) // beam_width
        self.batch_ids = Tensor(batch_ids, mstype.int32)
        self.concat = P.Concat(axis=-1)
        self.gather_nd = P.GatherNd()

        self.greater_equal = P.GreaterEqual()
        self.sub = P.Sub()
        self.cast = P.Cast()
        self.zeroslike = P.ZerosLike()

        # init inputs and states
        self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
        init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
        self.init_scores = Tensor(init_scores, mstype.float32)
        self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool_))
        self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
        self.length_penalty = LengthPenalty(weight=length_penalty_weight)
        self.one = Tensor(1, mstype.int32)

    def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
                 state_seq, state_finished, state_length):
        """Run one decoding step and update the beam search states."""
        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, self.seq_length)
        log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))

        # accumulate scores of each beam
        total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))

        # mask finished beams
        mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
        total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))

        # reshape scores to [batch, beam*vocab]
        flat_scores = self.reshape(total_log_probs, self.flat_shape)
        # select topk
        topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)

        # decompose flat indices into (beam index, word index)
        temp = topk_indices
        beam_indices = self.zeroslike(topk_indices)
        for _ in range(self.beam_width - 1):
            temp = self.sub(temp, self.vocab_size_tensor)
            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
            beam_indices = beam_indices + res
        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
        #======================================================================

        # mask finished indices
        beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
        word_indices = self.select(state_finished, self.eos_ids, word_indices)
        topk_scores = self.select(state_finished, state_log_probs, topk_scores)

        ###### put finished sequences to the end
        # sort according to scores with -inf for finished beams
        tmp_log_probs = self.select(
            self.equal(word_indices, self.eos_ids),
            self.ninf_tensor,
            topk_scores)
        _, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
        # update
        tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
        beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
        word_indices = self.gather_nd(word_indices, tmp_gather_indices)
        topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)

        ###### generate new beam_search states
        # gather indices for selecting alive beams
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))

        # length add 1 if not finished in the previous step
        length_add = self.add(state_length, self.one)
        state_length = self.select(state_finished, state_length, length_add)
        state_length = self.gather_nd(state_length, gather_indices)

        # concat seq
        seq = self.gather_nd(state_seq, gather_indices)
        state_seq = self.concat((seq, self.expand(word_indices, -1)))

        # new finished flag and log_probs
        state_finished = self.equal(word_indices, self.eos_ids)
        state_log_probs = topk_scores

        ###### generate new inputs and decoder states
        cur_input_ids = self.reshape(state_seq, (self.batch_size * self.beam_width, -1))
        return cur_input_ids, state_log_probs, state_seq, state_finished, state_length

    def construct(self, enc_states, enc_attention_mask):
        """Get beam search result."""
        cur_input_ids = self.start_ids
        # beam search states
        state_log_probs = self.init_scores
        state_seq = self.init_seq
        state_finished = self.init_finished
        state_length = self.init_length

        for _ in range(self.max_decode_length):
            # run one step decoder to get outputs of the current step
            # shape [batch*beam, 1, vocab]
            cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step(
                cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished,
                state_length)

        # add length penalty scores
        penalty_len = self.length_penalty(state_length)
        # normalize scores by the penalty length
        log_probs = self.real_div(state_log_probs, penalty_len)

        # sort according to scores
        _, top_beam_indices = self.topk(log_probs, self.beam_width)
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
        # sort sequence
        predicted_ids = self.gather_nd(state_seq, gather_indices)
        # take the first (best) beam
        predicted_ids = predicted_ids[:, 0:1, :]
        return predicted_ids
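
# Added usage sketch (hedged; names and shapes are assumptions, not part of
# the original file). `decoder_step` stands in for whatever decoder Cell the
# caller provides, e.g. a TransformerDecoderStep:
#
#   searcher = BeamSearchDecoder(batch_size=1, seq_length=128,
#                                vocab_size=36560, decoder=decoder_step,
#                                beam_width=4, max_decode_length=80)
#   predicted_ids = searcher(enc_states, enc_attention_mask)
#   # predicted_ids: [batch, 1, max_decode_length + 1] ids of the best beam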

@@ -1,86 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network config setting, used in dataset.py and train.py."""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .transformer_model import TransformerConfig

cfg = edict({
    'transformer_network': 'large',
    'init_loss_scale_value': 1024,
    'scale_factor': 2,
    'scale_window': 2000,
    'optimizer': 'Adam',
    'optimizer_adam_beta2': 0.997,
    'lr_schedule': edict({
        'learning_rate': 2.0,
        'warmup_steps': 8000,
        'start_decay_step': 16000,
        'min_lr': 0.0,
    }),
})

# Two variants of the Transformer model: 'large' and 'base'.
if cfg.transformer_network == 'large':
    transformer_net_cfg = TransformerConfig(
        batch_size=96,
        seq_length=128,
        vocab_size=36560,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="relu",
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
    transformer_net_cfg_gpu = TransformerConfig(
        batch_size=32,
        seq_length=128,
        vocab_size=36560,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="relu",
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
if cfg.transformer_network == 'base':
    transformer_net_cfg = TransformerConfig(
        batch_size=96,
        seq_length=128,
        vocab_size=36560,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_act="relu",
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
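
# Added sketch (hedged): how the lr_schedule fields above are typically fed
# into create_dynamic_lr from src/lr_schedule.py; the schedule string and
# step count are assumptions for illustration:
#
#   from src.lr_schedule import create_dynamic_lr
#   lr = create_dynamic_lr("constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
#                          training_steps=100000,
#                          learning_rate=cfg.lr_schedule.learning_rate,
#                          warmup_steps=cfg.lr_schedule.warmup_steps,
#                          hidden_size=transformer_net_cfg.hidden_size,
#                          start_decay_step=cfg.lr_schedule.start_decay_step,
#                          min_lr=cfg.lr_schedule.min_lr)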

@@ -1,58 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, used in train.py."""

import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from .config import transformer_net_cfg, transformer_net_cfg_gpu

de.config.set_seed(1)


def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None,
                               bucket_boundaries=None, device_target="Ascend"):
    """Create a bucketed transformer dataset."""
    def batch_per_bucket(bucket_len, dataset_path):
        dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
        ds = de.MindDataset(dataset_path,
                            columns_list=["source_eos_ids", "source_eos_mask",
                                          "target_sos_ids", "target_sos_mask",
                                          "target_eos_ids", "target_eos_mask"],
                            shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)
        type_cast_op = deC.TypeCast(mstype.int32)
        ds = ds.map(operations=type_cast_op, input_columns="source_eos_ids")
        ds = ds.map(operations=type_cast_op, input_columns="source_eos_mask")
        ds = ds.map(operations=type_cast_op, input_columns="target_sos_ids")
        ds = ds.map(operations=type_cast_op, input_columns="target_sos_mask")
        ds = ds.map(operations=type_cast_op, input_columns="target_eos_ids")
        ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")

        # apply batch operations
        if device_target == "Ascend":
            ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
        else:
            ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True)

        ds = ds.repeat(epoch_count)
        return ds

    for i, bucket_len in enumerate(bucket_boundaries):
        ds_per = batch_per_bucket(bucket_len, dataset_path)
        if i == 0:
            ds = ds_per
        else:
            ds = ds + ds_per
    ds = ds.shuffle(ds.get_dataset_size())
    ds.channel_name = 'transformer'
    return ds
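
# Added usage sketch (hedged; the path and boundaries are illustrative
# assumptions). Each bucket reads MindRecord files named
# "<dataset_path>_<bucket_len>_00":
#
#   ds = create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0,
#                                   do_shuffle="true",
#                                   dataset_path="/data/train-mindrecord",
#                                   bucket_boundaries=[16, 32, 64, 128])
#   print(ds.get_dataset_size())  # batches across all buckets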

@@ -1,67 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network evaluation config setting, used in eval.py."""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .transformer_model import TransformerConfig

cfg = edict({
    'transformer_network': 'large',
    'data_file': '/your/path/evaluation.mindrecord',
    'model_file': '/your/path/checkpoint_file',
    'output_file': '/your/path/output',
})

# Two variants of the Transformer model: 'large' and 'base'.
if cfg.transformer_network == 'large':
    transformer_net_cfg = TransformerConfig(
        batch_size=1,
        seq_length=128,
        vocab_size=36560,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="relu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=128,
        label_smoothing=0.1,
        beam_width=4,
        max_decode_length=80,
        length_penalty_weight=1.0,
        dtype=mstype.float32,
        compute_type=mstype.float16)
if cfg.transformer_network == 'base':
    transformer_net_cfg = TransformerConfig(
        batch_size=1,
        seq_length=128,
        vocab_size=36560,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_act="relu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=128,
        label_smoothing=0.1,
        beam_width=4,
        max_decode_length=80,
        length_penalty_weight=1.0,
        dtype=mstype.float32,
        compute_type=mstype.float16)
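
# Added sketch (hedged, speculative): how the decoding fields above could be
# wired into the BeamSearchDecoder defined in beam_search.py; `decoder_step`
# is a hypothetical decoder Cell, not part of the original file:
#
#   searcher = BeamSearchDecoder(
#       batch_size=transformer_net_cfg.batch_size,
#       seq_length=transformer_net_cfg.seq_length,
#       vocab_size=transformer_net_cfg.vocab_size,
#       decoder=decoder_step,
#       beam_width=transformer_net_cfg.beam_width,
#       max_decode_length=transformer_net_cfg.max_decode_length,
#       length_penalty_weight=transformer_net_cfg.length_penalty_weight)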

@@ -1,52 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate utilities."""

def linear_warmup(warmup_steps, current_step):
    return min([1.0, float(current_step)/float(warmup_steps)])


def rsqrt_decay(warmup_steps, current_step):
    return float(max([current_step, warmup_steps])) ** -0.5


def rsqrt_hidden(hidden_size):
    return float(hidden_size) ** -0.5


def create_dynamic_lr(schedule, training_steps, learning_rate, warmup_steps, hidden_size,
                      start_decay_step=0, min_lr=0.):
    """
    Generate dynamic learning rate.
    """
    if start_decay_step < warmup_steps:
        start_decay_step = warmup_steps
    lr = []
    for current_step in range(1, training_steps+1):
        cur_lr = 1.0
        for name in schedule.split("*"):
            if name == "constant":
                cur_lr *= float(learning_rate)
            elif name == "rsqrt_hidden":
                cur_lr *= rsqrt_hidden(hidden_size)
            elif name == "linear_warmup":
                cur_lr *= linear_warmup(warmup_steps, current_step)
            elif name == "rsqrt_decay":
                cur_lr *= rsqrt_decay(warmup_steps, current_step-start_decay_step+warmup_steps)
            else:
                raise ValueError("unknown learning rate schedule")
        if warmup_steps < current_step < start_decay_step:
            cur_lr = lr[-1]
        if current_step > warmup_steps:
            cur_lr = max([cur_lr, min_lr])
        lr.append(cur_lr)
    return lr
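
# Added worked example (the schedule string is an assumption based on the
# branch names above). With learning_rate=2.0, hidden_size=1024,
# warmup_steps=8000, start_decay_step=16000, at step 4000:
#   constant      -> 2.0
#   rsqrt_hidden  -> 1024 ** -0.5 = 0.03125
#   linear_warmup -> 4000 / 8000 = 0.5
#   rsqrt_decay   -> max(4000 - 16000 + 8000, 8000) ** -0.5 = 8000 ** -0.5
# so the step-4000 rate is 2.0 * 0.03125 * 0.5 * 8000 ** -0.5 ~= 3.49e-4:
#
#   lr = create_dynamic_lr("constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
#                          training_steps=20000, learning_rate=2.0,
#                          warmup_steps=8000, hidden_size=1024,
#                          start_decay_step=16000, min_lr=0.0)
#   print(lr[3999])  # ~3.49e-4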

@@ -1,47 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ids to tokens."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tokenization

# Explicitly set the encoding
sys.stdin = open(sys.stdin.fileno(), mode='r', encoding='utf-8', buffering=True)
sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=True)


def main():
    parser = argparse.ArgumentParser(
        description="Convert token ids read from stdin to text.")
    parser.add_argument("--vocab_file", type=str, default="", required=True, help="vocab file path.")
    args = parser.parse_args()

    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    for line in sys.stdin:
        token_ids = [int(x) for x in line.strip().split()]
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        sent = " ".join(tokens)
        sent = sent.split("<s>")[-1]
        sent = sent.split("</s>")[0]
        print(sent.strip())


if __name__ == "__main__":
    main()
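
# Added usage note (file names are illustrative): run as a stdin/stdout
# filter; each input line holds space-separated token ids and each output
# line is the text between <s> and </s>:
#
#   python process_output.py --vocab_file vocab.txt < output_ids.txt > output.txt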

@@ -1,158 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenization utilities."""

import sys
import collections
import unicodedata

def convert_to_printable(text):
    """
    Converts `text` to a printable coding format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text
        if isinstance(text, unicode):
            return text.encode("utf-8")
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def convert_to_unicode(text):
    """
    Converts `text` to Unicode format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def load_vocab_file(vocab_file):
    """
    Loads a vocabulary file and turns it into a {token: id} dictionary.
    """
    vocab_dict = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as vocab:
        while True:
            token = convert_to_unicode(vocab.readline())
            if not token:
                break
            token = token.strip()
            vocab_dict[token] = index
            index += 1
    return vocab_dict


def convert_by_vocab_dict(vocab_dict, items):
    """
    Converts a sequence of [tokens|ids] according to the vocab dict.
    """
    output = []
    for item in items:
        if item in vocab_dict:
            output.append(vocab_dict[item])
        else:
            output.append(vocab_dict["<unk>"])
    return output


class WhiteSpaceTokenizer:
    """
    Whitespace tokenizer.
    """
    def __init__(self, vocab_file):
        self.vocab_dict = load_vocab_file(vocab_file)
        self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}

    def _is_whitespace_char(self, char):
        """
        Checks if it is a whitespace character (regard "\t", "\n", "\r" as whitespace here).
        """
        if char in (" ", "\t", "\n", "\r"):
            return True
        uni = unicodedata.category(char)
        if uni == "Zs":
            return True
        return False

    def _is_control_char(self, char):
        """
        Checks if it is a control character.
        """
        if char in ("\t", "\n", "\r"):
            return False
        uni = unicodedata.category(char)
        if uni in ("Cc", "Cf"):
            return True
        return False

    def _clean_text(self, text):
        """
        Remove invalid characters and clean up whitespace.
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self._is_control_char(char):
                continue
            if self._is_whitespace_char(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _whitespace_tokenize(self, text):
        """
        Clean whitespace and split text into tokens.
        """
        text = text.strip()
        if not text:
            tokens = []
        else:
            tokens = text.split()
        return tokens

    def tokenize(self, text):
        """
        Tokenizes text.
        """
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        tokens = self._whitespace_tokenize(text)
        return tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab_dict(self.vocab_dict, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab_dict(self.inv_vocab_dict, ids)
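
# Added usage sketch ("vocab.txt" and the ids are illustrative assumptions;
# ids depend entirely on the vocabulary file):
#
#   tokenizer = WhiteSpaceTokenizer(vocab_file="vocab.txt")
#   tokens = tokenizer.tokenize("Hello   world\n")  # ['Hello', 'world']
#   ids = tokenizer.convert_tokens_to_ids(tokens)   # e.g. [52, 713]
#   tokenizer.convert_ids_to_tokens(ids)            # ['Hello', 'world']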

@@ -1,472 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer for training."""
import numpy as np

from mindspore.common.initializer import initializer
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore import context

from .transformer_model import TransformerModel

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 5.0

clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad


class TransformerTrainingLoss(nn.Cell):
    """
    Provide transformer training loss.

    Args:
        config (TransformerConfig): The config of Transformer.

    Returns:
        Tensor, total loss.
    """
    def __init__(self, config):
        super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
        self.vocab_size = config.vocab_size
        self.onehot = P.OneHot()
        self.on_value = Tensor(float(1 - config.label_smoothing), mstype.float32)
        self.off_value = Tensor(config.label_smoothing / float(self.vocab_size - 1), mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.flatten = P.Flatten()
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.batch_size = config.batch_size

    def construct(self, prediction_scores, label_ids, label_weights, seq_length):
        """Defines the computation performed."""
        flat_shape = (self.batch_size * seq_length,)
        label_ids = self.reshape(label_ids, flat_shape)
        label_weights = self.cast(self.reshape(label_weights, flat_shape), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + \
                      self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
        loss = numerator / denominator
        return loss
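
# Added worked check (illustrative): with label_smoothing=0.1 and a toy
# vocab_size=5, on_value=0.9 and off_value=0.1/4=0.025. A numpy
# re-derivation of per_example_loss for one position:
#
#   import numpy as np
#   log_probs = np.log(np.array([0.7, 0.1, 0.1, 0.05, 0.05]))
#   target = np.full(5, 0.025)
#   target[0] = 0.9                                 # gold token
#   per_example_loss = -np.sum(log_probs * target)  # ~0.59 nats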


class TransformerNetworkWithLoss(nn.Cell):
    """
    Provide transformer training loss through network.

    Args:
        config (TransformerConfig): The config of Transformer.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False)
        self.transformer = TransformerModel(config, is_training, use_one_hot_embeddings)
        self.loss = TransformerTrainingLoss(config)
        self.cast = P.Cast()
        self.shape = P.Shape()

    def construct(self,
                  source_ids,
                  source_mask,
                  target_ids,
                  target_mask,
                  label_ids,
                  label_weights):
        """Transformer network with loss."""
        prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
        seq_length = self.shape(source_ids)[1]
        total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
        return self.cast(total_loss, mstype.float32)


class TransformerTrainOneStepCell(nn.TrainOneStepCell):
    """
    Encapsulation class of transformer network training.

    Append an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TransformerTrainOneStepCell, self).__init__(network, optimizer, sens)

        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def set_sens(self, value):
        self.sens = value

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        succ = self.optimizer(grads)
        return F.depend(loss, succ)
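
# Added wiring sketch (hedged; optimizer settings are assumptions): how the
# cells above are typically assembled for a single training step:
#
#   netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, is_training=True)
#   optimizer = nn.Adam(netwithloss.trainable_params(), learning_rate=1e-4)
#   train_cell = TransformerTrainOneStepCell(netwithloss, optimizer)
#   train_cell.set_train()
#   loss = train_cell(source_eos_ids, source_eos_mask, target_sos_ids,
#                     target_sos_mask, target_eos_ids, target_eos_mask)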


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)


class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of Transformer network training.

    Append an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TransformerTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
        self.cast = P.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)


cast = P.Cast()
add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, mstype.float32)


update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))


accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")


@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of transformer network training.

    Append an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    To mimic a higher batch size, gradients are accumulated N times before the weight update.

    In distributed mode, allreduce is applied only in the weight-update step,
    i.e. the sub-step after gradients have been accumulated N times.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 1.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=8, enable_global_norm=False):
        super(TransformerTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        init = F.depend(init, loss)
        clear_status = self.clear_status(init)
        scaling_sens = F.depend(scaling_sens, clear_status)
        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        init = F.depend(init, mean_loss)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            if self.enable_global_norm:
                grads = C.clip_by_global_norm(grads, 1.0, None)
            else:
                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
            accu_overflow = F.depend(accu_overflow, grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, ())
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
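
# Added sketch (hedged; the numbers are assumptions): the accumulation cell
# drops in where the plain loss-scale cell is used, updating weights every
# accumulation_steps sub-steps:
#
#   update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=1024,
#                                               scale_factor=2,
#                                               scale_window=2000)
#   train_cell = TransformerTrainAccumulationAllReducePostWithLossScaleCell(
#       netwithloss, optimizer, scale_update_cell=update_cell,
#       accumulation_steps=8)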

[File diff suppressed because it is too large]

@@ -1,52 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Weight init utilities."""

import math
import numpy as np
from mindspore.common.tensor import Tensor

def _average_units(shape):
    """
    Average the dimensions of a shape.
    """
    if not shape:
        return 1.
    if len(shape) == 1:
        return float(shape[0])
    if len(shape) == 2:
        return float(shape[0] + shape[1]) / 2.
    raise RuntimeError("unsupported shape.")

def weight_variable(shape):
    scale_shape = shape
    avg_units = _average_units(scale_shape)
    scale = 1.0 / max(1., avg_units)
    limit = math.sqrt(3.0 * scale)
    values = np.random.uniform(-limit, limit, shape).astype(np.float32)
    return Tensor(values)

def one_weight(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)

def zero_weight(shape):
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)

def normal_weight(shape, num_units):
    norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
    return Tensor(norm)
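
# Added usage sketch (shapes are illustrative). weight_variable draws from a
# uniform Xavier-style range based on the average fan:
#
#   w = weight_variable((1024, 512))        # uniform in +/- sqrt(3 / 768)
#   g = one_weight((1024,))                 # e.g. layer-norm gamma
#   b = zero_weight((1024,))                # e.g. bias
#   e = normal_weight((36560, 1024), 1024)  # std = 1024 ** -0.5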
@@ -27,14 +27,30 @@ from mindspore.train.callback import Callback
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as deC
 from mindspore import context
-from src.transformer_model import TransformerConfig
-from src.transformer_for_train import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from src.config import cfg, transformer_net_cfg
-from src.lr_schedule import create_dynamic_lr
+from easydict import EasyDict as edict
+from model_zoo.official.nlp.transformer.src.transformer_model import TransformerConfig
+from model_zoo.official.nlp.transformer.src.transformer_for_train import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
+from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr
 from tests.st.model_zoo_tests import utils


 DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]

+cfg = edict({
+    'transformer_network': 'large',
+    'init_loss_scale_value': 1024,
+    'scale_factor': 2,
+    'scale_window': 2000,
+    'optimizer': 'Adam',
+    'optimizer_adam_beta2': 0.997,
+    'lr_schedule': edict({
+        'learning_rate': 2.0,
+        'warmup_steps': 8000,
+        'start_decay_step': 16000,
+        'min_lr': 0.0,
+    }),
+})

 def get_config(version='base', batch_size=1):
     """get config"""
@@ -129,7 +145,7 @@ class TimeMonitor(Callback):
         self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)


-@pytest.mark.level2
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
@@ -144,7 +160,7 @@ def test_transformer():
     batch_size = 96
     epoch_size = 3
     config = get_config(version=version, batch_size=batch_size)
-    dataset = load_test_data(batch_size=transformer_net_cfg.batch_size, data_file=DATA_DIR)
+    dataset = load_test_data(batch_size=config.batch_size, data_file=DATA_DIR)

     netwithloss = TransformerNetworkWithLoss(config, True)
@@ -201,7 +217,7 @@ def test_transformer():
     assert per_step_mseconds <= expect_per_step_mseconds + 10


-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard