!22360 Fix Transformer Mirror Error

Merge pull request !22360 from huangxinjing/fix_transformer_mirror_error
This commit is contained in:
i-robot 2021-08-26 08:16:33 +00:00 committed by Gitee
commit 8d00a8d803
4 changed files with 98 additions and 46 deletions

View File

@ -21,6 +21,7 @@ from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore import log as logger
__all__ = [
"OpParallelConfig"
@ -56,8 +57,8 @@ class OpParallelConfig(_Config):
def __init__(self, data_parallel=1, model_parallel=1):
Validator.check_positive_int(data_parallel, "data_parallel")
Validator.check_positive_int(model_parallel, "model_parallel")
self._data_parallel = data_parallel
self._model_parallel = model_parallel
self.data_parallel = data_parallel
self.model_parallel = model_parallel
@property
def data_parallel(self):
@ -95,8 +96,8 @@ class _PipeLineConfig(_Config):
def __init__(self, pipeline_stage=1, micro_batch_num=1):
Validator.check_positive_int(pipeline_stage, "pipeline_stage")
Validator.check_positive_int(micro_batch_num, "micro_batch_num")
self._pipeline_stage = pipeline_stage
self._micro_batch_num = micro_batch_num
self.pipeline_stage = pipeline_stage
self.micro_batch_num = micro_batch_num
@property
def pipeline_stage(self):
@ -150,9 +151,10 @@ def _check_config(config):
"should be less than device_num {device_num}")
# the config optimizer_shard is same with context.optimizer_shard
if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
f"optimizer_shard {config.optimizer_shard} in the config")
if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
f"optimizer_shard to make them consistent.")
# pipeline_stage <= micro_batch_num
if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
NOTE:
Note:
Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
"""
import math
@ -50,7 +50,7 @@ __all__ = [
@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
if len(input_shape) != target_len:
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
return True
@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config):
def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
self._vocab_emb_dp = vocab_emb_dp
self.vocab_emb_dp = vocab_emb_dp
@property
def data_parallel(self):
@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config):
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
Validator.check_bool(recompute, "recompute")
Validator.check_bool(optimizer_shard, "optimizer_shard")
Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group")
self.recompute = recompute
self.optimizer_shard = optimizer_shard
self.gradient_aggregation_group = gradient_aggregation_group
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
vocab_emb_dp=vocab_emb_dp)
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
self._recompute = recompute
self._optimizer_shard = optimizer_shard
self._gradient_aggregation_group = gradient_aggregation_group
@property
def recompute(self):
@ -256,7 +253,7 @@ class TransformerOpParallelConfig(_Config):
def optimizer_shard(self, value):
Validator.check_bool(value, "optimizer_shard")
self._optimizer_shard = value
context.set_auto_parallel_context(optimizer_shard=value)
context.set_auto_parallel_context(enable_parallel_optimizer=value)
@property
def embedding_dp_mp_config(self):
@ -322,6 +319,8 @@ class FeedForward(Cell):
Raises:
ValueError: `hidden_act` is not a string.
ValueError: `parallel_config` is not a subclass of OpParallelConfig.
ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
ValueError: `hidden_size` is not a multiple of the model parallel way.
Supported Platforms:
``Ascend`` ``GPU``
@ -349,6 +348,13 @@ class FeedForward(Cell):
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
if ffn_hidden_size % mp != 0:
raise ValueError("ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
if hidden_size % mp != 0:
raise ValueError("hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
if dropout_rate < 0 or dropout_rate >= 1:
raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
"but got {}".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size
# Project to ffn_hidden_size
@ -428,7 +434,7 @@ class AttentionMask(Cell):
raise ValueError(
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
self.seq_length = seq_length
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),))
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
self.reshape = P.Reshape()
self.mul = P.BatchMatMul().shard(
((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
@ -492,6 +498,7 @@ class VocabEmbedding(Cell):
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
>>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
@ -612,7 +619,17 @@ class MultiHeadAttention(Cell):
_check_config(parallel_config)
self.src_seq_length = src_seq_length
self.tgt_seq_length = tgt_seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
Validator.check_positive_int(num_heads, "num_heads")
if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
"but got {}".format(hidden_dropout_rate))
if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
"but got {}".format(attention_dropout_rate))
if hidden_size % num_heads != 0:
raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
if num_heads % parallel_config.model_parallel != 0:
raise ValueError(f"The number of heads {num_heads} must be a "
f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
@ -719,7 +736,8 @@ class MultiHeadAttention(Cell):
output: Tensor, the output logits of this layer
layer_present: Tensor, the feature map of current layer
"""
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
value_past, batch_valid_length)
query_tensor_original_shape = F.shape(query_tensor)
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
@ -795,7 +813,7 @@ class MultiHeadAttention(Cell):
# multi head attention considering attention mask
attention = self._attn(query, key, value, attention_mask)
# [bs, seq_length, embedding_size]
attention_merge = self.merge_heads(attention)
attention_merge = self._merge_heads(attention)
# Output
output = self.projection(attention_merge)
output = self.dropout(output)
@ -804,10 +822,14 @@ class MultiHeadAttention(Cell):
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
r"""Check inputs"""
_check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3)
_check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3)
_check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3)
_check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3)
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[self.batch_size, self.src_seq_length, self.hidden_size])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.hidden_size])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.hidden_size])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
@ -816,23 +838,8 @@ class MultiHeadAttention(Cell):
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
def split_heads(self, x, transpose):
"""
split 3d tensor to 4d and switch certain axes
Inputs:
x: input tensor
transpose: tuple, the transpose sequence
Outputs:
x_transpose: the 4d output
"""
x_size = P.Shape()(x)
new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
x = self.reshape(x, new_x_shape)
x_transpose = self.transpose(x, transpose)
return x_transpose
def merge_heads(self, x):
return True
def _merge_heads(self, x):
"""
convert a 4d input to a 3d output
@ -1235,8 +1242,8 @@ class TransformerDecoderLayer(Cell):
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_comptue_type=softmax_comptue_type,
@ -1705,9 +1712,9 @@ class Transformer(Cell):
r"""
Transformer module including encoder and decoder. The difference with the original implements is the module use
the residual addition before the layernormalization. And the default hidden act is `gelu`.
The detials can be found in `Attention is all you need<https://arxiv.org/pdf/1706.03762v5.pdf>`_.
The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
NOTE:
Note:
This is an experimental interface that is subject to change and/or deletion.
Args:
@ -1832,7 +1839,7 @@ class Transformer(Cell):
raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
f"layer {decoder_layers}, please use TransformerDecoder")
if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
raise ValueError("The transformer with encoder and decoder does not support use_past.")
raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
# The shard setting of Transformer is set within the class StackedTransformer
if not lambda_func:
lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)

View File

@ -203,6 +203,20 @@ def test_multihead_attention():
_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_multihead_attention_wrong_batch():
model = MultiHeadAttention(hidden_size=15,
src_seq_length=20,
tgt_seq_length=20,
batch_size=2,
num_heads=3)
from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
with pytest.raises(ValueError):
_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_feedforward_layer():
model = FeedForward(hidden_size=15,
ffn_hidden_size=30,

View File

@ -212,6 +212,35 @@ def test_pipeline_single_transformer():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_wrong_head():
set_auto_parallel_context(device_num=32,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
with pytest.raises(ValueError):
net = Transformer(batch_size=4,
src_seq_length=20,
tgt_seq_length=10,
encoder_layers=2,
decoder_layers=2,
hidden_size=64,
num_heads=7,
ffn_hidden_size=64,
parallel_config=error_test_config)
with pytest.raises(ValueError):
net = Transformer(batch_size=4,
src_seq_length=20,
tgt_seq_length=10,
encoder_layers=2,
decoder_layers=2,
hidden_size=63,
num_heads=7,
ffn_hidden_size=64,
parallel_config=error_test_config)
del net
def test_encoder():
class NetWithLoss(nn.Cell):
def __init__(self, network):