!22360 Fix Transformer Mirror Error

Merge pull request !22360 from huangxinjing/fix_transformer_mirror_error
Commit 8d00a8d803 by i-robot, 2021-08-26 08:16:33 +00:00, committed by Gitee
4 changed files with 98 additions and 46 deletions

Changed file 1 of 4

@@ -21,6 +21,7 @@ from mindspore import context
 import mindspore.communication.management as D
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore import log as logger
 
 __all__ = [
     "OpParallelConfig"
@@ -56,8 +57,8 @@ class OpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1):
         Validator.check_positive_int(data_parallel, "data_parallel")
         Validator.check_positive_int(model_parallel, "model_parallel")
-        self._data_parallel = data_parallel
-        self._model_parallel = model_parallel
+        self.data_parallel = data_parallel
+        self.model_parallel = model_parallel
 
     @property
     def data_parallel(self):
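
Assigning through the public properties instead of the private _data_parallel / _model_parallel fields means the property setters also run during construction, so any setter-side logic is applied consistently whether the value comes from __init__ or a later reassignment. A minimal standalone Python sketch of the pattern (not the MindSpore class itself):

# Minimal sketch: writing to the public attribute in __init__ goes through the setter.
class ConfigSketch:
    def __init__(self, data_parallel=1):
        self.data_parallel = data_parallel  # invokes the setter below

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            raise ValueError(f"data_parallel should be a positive int, but got {value}")
        self._data_parallel = value

cfg = ConfigSketch(data_parallel=4)
print(cfg.data_parallel)  # 4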
@@ -95,8 +96,8 @@ class _PipeLineConfig(_Config):
     def __init__(self, pipeline_stage=1, micro_batch_num=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
-        self._pipeline_stage = pipeline_stage
-        self._micro_batch_num = micro_batch_num
+        self.pipeline_stage = pipeline_stage
+        self.micro_batch_num = micro_batch_num
 
     @property
     def pipeline_stage(self):
@@ -150,9 +151,10 @@ def _check_config(config):
                          "should be less than device_num {device_num}")
     # the config optimizer_shard is same with context.optimizer_shard
-    if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
-        raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
-                         f"optimizer_shard {config.optimizer_shard} in the config")
+    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
+        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
+                       f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
+                       f"optimizer_shard to make them consistent.")
 
     # pipeline_stage <= micro_batch_num
     if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
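
A mismatch between the optimizer shard flag in auto_parallel_context and the one in the config is now reported as a warning instead of raising, and only when the context flag is actually enabled. A standalone sketch of the relaxed check (hypothetical helper, not the MindSpore code):

# Hypothetical standalone version of the relaxed consistency check.
import logging

logger = logging.getLogger("op_parallel_config")

def check_optimizer_shard(context_optimizer_shard, config_optimizer_shard):
    """Warn (instead of raising) when the two optimizer_shard settings disagree."""
    if context_optimizer_shard and context_optimizer_shard != config_optimizer_shard:
        logger.warning("The optimizer shard %s in auto_parallel_context is not equal to the "
                       "optimizer_shard %s in the OpParallelConfig. Please check the "
                       "optimizer_shard to make them consistent.",
                       context_optimizer_shard, config_optimizer_shard)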

Changed file 2 of 4

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-NOTE:
+Note:
     Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
 """
 import math
@@ -50,7 +50,7 @@ __all__ = [
 @constexpr
 def _check_input_shape(input_shape, param_name, func_name, target_len):
     if len(input_shape) != target_len:
-        raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
+        raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
     return True
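
The error message now reports the expected rank from target_len instead of the hard-coded "2d". A small standalone illustration of the corrected behaviour:

# Illustration only: the expected rank comes from target_len.
def check_input_shape(input_shape, param_name, func_name, target_len):
    if len(input_shape) != target_len:
        raise ValueError(f"{func_name} {param_name} should be {target_len}d, "
                         f"but got shape {input_shape}")
    return True

# check_input_shape((20, 15), "attention_mask", "AttentionMask", 3)
# -> ValueError: AttentionMask attention_mask should be 3d, but got shape (20, 15)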
@@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
         self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
         Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
-        self._vocab_emb_dp = vocab_emb_dp
+        self.vocab_emb_dp = vocab_emb_dp
 
     @property
     def data_parallel(self):
@@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
-        Validator.check_bool(recompute, "recompute")
-        Validator.check_bool(optimizer_shard, "optimizer_shard")
-        Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group")
+        self.recompute = recompute
+        self.optimizer_shard = optimizer_shard
+        self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                                              vocab_emb_dp=vocab_emb_dp)
         self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
-        self._recompute = recompute
-        self._optimizer_shard = optimizer_shard
-        self._gradient_aggregation_group = gradient_aggregation_group
 
     @property
     def recompute(self):
@@ -256,7 +253,7 @@
     def optimizer_shard(self, value):
         Validator.check_bool(value, "optimizer_shard")
         self._optimizer_shard = value
-        context.set_auto_parallel_context(optimizer_shard=value)
+        context.set_auto_parallel_context(enable_parallel_optimizer=value)
 
     @property
     def embedding_dp_mp_config(self):
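
The setter now forwards the flag under the enable_parallel_optimizer keyword when it updates the auto parallel context. A hedged usage sketch (the import path for TransformerOpParallelConfig is assumed from this branch's layout and may differ between MindSpore versions):

# Hedged usage sketch; the import path below is an assumption, not confirmed by this diff.
from mindspore import context
from mindspore.parallel.nn import TransformerOpParallelConfig  # assumed path

cfg = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
cfg.optimizer_shard = True  # also enables the parallel optimizer in the global context
print(context.get_auto_parallel_context("enable_parallel_optimizer"))  # expected: True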
@@ -322,6 +319,8 @@ class FeedForward(Cell):
     Raises:
         ValueError: `hidden_act` is not a string.
         ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+        ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
+        ValueError: `hidden_size` is not a multiple of the model parallel way.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -349,6 +348,13 @@
                 f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
+        if ffn_hidden_size % mp != 0:
+            raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
+        if hidden_size % mp != 0:
+            raise ValueError(f"hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
+        if dropout_rate < 0 or dropout_rate >= 1:
+            raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(dropout_rate))
         input_size = hidden_size
         output_size = ffn_hidden_size
         # Project to ffn_hidden_size
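
The new checks make the sharding and dropout constraints fail fast when FeedForward is constructed. A standalone sketch of the same validation (hypothetical helper name, not the MindSpore source):

# Hypothetical standalone helper mirroring the new FeedForward argument checks.
def validate_feedforward_args(hidden_size, ffn_hidden_size, dropout_rate, model_parallel):
    if ffn_hidden_size % model_parallel != 0:
        raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of "
                         f"the model parallel way {model_parallel}")
    if hidden_size % model_parallel != 0:
        raise ValueError(f"hidden_size {hidden_size} should be a multiple of "
                         f"the model parallel way {model_parallel}")
    if dropout_rate < 0 or dropout_rate >= 1:
        raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
                         "but got {}".format(dropout_rate))

# Example: hidden_size=63 cannot be split evenly across 8 model-parallel slices.
# validate_feedforward_args(hidden_size=63, ffn_hidden_size=64, dropout_rate=0.1, model_parallel=8)  # ValueError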
@@ -428,7 +434,7 @@ class AttentionMask(Cell):
             raise ValueError(
                 f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         self.seq_length = seq_length
-        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),))
+        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
         self.reshape = P.Reshape()
         self.mul = P.BatchMatMul().shard(
             ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
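
The NotEqual strategy now carries one entry per operator input, with an empty strategy for the second input (presumably a scalar constant in this class), so the strategy tuple matches the operator's input count during auto-parallel compilation. A hedged sketch of the corrected call:

# Hedged sketch of the corrected shard strategy (illustrative data_parallel value).
from mindspore.ops import operations as P

data_parallel = 2
not_equal = P.NotEqual().shard(((data_parallel, 1), ()))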
@@ -492,6 +498,7 @@ class VocabEmbedding(Cell):
     Supported Platforms:
         ``Ascend`` ``GPU``
+
     Examples:
         >>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
         >>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
@@ -612,7 +619,17 @@ class MultiHeadAttention(Cell):
         _check_config(parallel_config)
         self.src_seq_length = src_seq_length
         self.tgt_seq_length = tgt_seq_length
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size
         Validator.check_positive_int(num_heads, "num_heads")
+        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
+            raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(hidden_dropout_rate))
+        if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
+            raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(attention_dropout_rate))
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
@@ -719,7 +736,8 @@
             output: Tensor, the output logits of this layer
             layer_present: Tensor, the feature map of current layer
         """
+        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
+                           value_past, batch_valid_length)
         query_tensor_original_shape = F.shape(query_tensor)
         query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
@@ -795,7 +813,7 @@
         # multi head attention considering attention mask
         attention = self._attn(query, key, value, attention_mask)
         # [bs, seq_length, embedding_size]
-        attention_merge = self.merge_heads(attention)
+        attention_merge = self._merge_heads(attention)
         # Output
         output = self.projection(attention_merge)
         output = self.dropout(output)
@@ -804,10 +822,14 @@
     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                       value_past=None, batch_valid_length=None):
         r"""Check inputs"""
-        _check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3)
+        _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.tgt_seq_length])
         _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
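
The rank-only _check_input_shape calls are replaced by exact shape checks against the configured batch size, sequence lengths, and hidden size, which is what the new test_multihead_attention_wrong_batch test below exercises. A minimal sketch of such a check (hypothetical helper, not the module's _check_shape_equal itself):

# Hypothetical shape-equality check in the spirit of the stricter validation above.
def check_shape_equal(input_shape, param_name, func_name, target_shape):
    if list(input_shape) != list(target_shape):
        raise ValueError(f"{func_name} {param_name} should have shape {target_shape}, "
                         f"but got {tuple(input_shape)}")
    return True

# A query tensor with batch 3 fails against a model configured with batch_size=2:
# check_shape_equal((3, 20, 15), "query_tensor", "MultiHeadAttention", [2, 20, 15])  # ValueError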
@@ -816,23 +838,8 @@
         _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
         _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
         _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
+        return True
 
-    def split_heads(self, x, transpose):
-        """
-        split 3d tensor to 4d and switch certain axes
-        Inputs:
-            x: input tensor
-            transpose: tuple, the transpose sequence
-        Outputs:
-            x_transpose: the 4d output
-        """
-        x_size = P.Shape()(x)
-        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
-        x = self.reshape(x, new_x_shape)
-        x_transpose = self.transpose(x, transpose)
-        return x_transpose
-
-    def merge_heads(self, x):
+    def _merge_heads(self, x):
         """
         convert a 4d input to a 3d output
@@ -1235,8 +1242,8 @@
         self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
                                                   num_heads=num_heads,
                                                   batch_size=batch_size,
-                                                  src_seq_length=src_seq_length,
-                                                  tgt_seq_length=tgt_seq_length,
+                                                  src_seq_length=tgt_seq_length,
+                                                  tgt_seq_length=src_seq_length,
                                                   hidden_dropout_rate=hidden_dropout_rate,
                                                   attention_dropout_rate=attention_dropout_rate,
                                                   softmax_comptue_type=softmax_comptue_type,
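
In cross attention the query comes from the decoder side while the keys and values come from the encoder output, and the updated _check_inputs validates the query against src_seq_length and the key/value against tgt_seq_length; swapping the two arguments therefore makes the stricter shape checks pass for TransformerDecoderLayer. An illustrative shape sketch (plain NumPy, not MindSpore):

# Illustrative numbers only: the decoder output is the query, the encoder output is the
# key/value, so the attention scores have shape (batch, heads, decoder_len, encoder_len).
import numpy as np

batch, heads, decoder_len, encoder_len, size_per_head = 2, 4, 10, 20, 8
query = np.ones((batch, heads, decoder_len, size_per_head))
key = np.ones((batch, heads, encoder_len, size_per_head))
scores = query @ key.transpose(0, 1, 3, 2)
print(scores.shape)  # (2, 4, 10, 20)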
@@ -1705,9 +1712,9 @@
     r"""
     Transformer module including encoder and decoder. The difference with the original implements is the module use
     the residual addition before the layernormalization. And the default hidden act is `gelu`.
-    The detials can be found in `Attention is all you need<https://arxiv.org/pdf/1706.03762v5.pdf>`_.
+    The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
 
-    NOTE:
+    Note:
         This is an experimental interface that is subject to change and/or deletion.
 
     Args:
@@ -1832,7 +1839,7 @@
             raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
                              f"layer {decoder_layers}, please use TransformerDecoder")
         if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
-            raise ValueError("The transformer with encoder and decoder does not support use_past.")
+            raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
         # The shard setting of Transformer is set within the class StackedTransformer
         if not lambda_func:
             lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)

Changed file 3 of 4

@@ -203,6 +203,20 @@ def test_multihead_attention():
     _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
 
+
+def test_multihead_attention_wrong_batch():
+    model = MultiHeadAttention(hidden_size=15,
+                               src_seq_length=20,
+                               tgt_seq_length=20,
+                               batch_size=2,
+                               num_heads=3)
+    from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
+
+    with pytest.raises(ValueError):
+        _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
 
 def test_feedforward_layer():
     model = FeedForward(hidden_size=15,
                         ffn_hidden_size=30,

Changed file 4 of 4

@@ -212,6 +212,35 @@ def test_pipeline_single_transformer():
     model.train(1, dataset, dataset_sink_mode=False)
 
+
+def test_transformer_wrong_head():
+    set_auto_parallel_context(device_num=32,
+                              full_batch=True,
+                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
+                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+    error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=64,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=63,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+        del net
+
 
 def test_encoder():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):