less the interface exposed
This commit is contained in:
parent
71874094c4
commit
62496d75f3
|
@ -21,6 +21,7 @@ from mindspore import context
|
|||
import mindspore.communication.management as D
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.parallel._utils import _get_parallel_mode
|
||||
from mindspore import log as logger
|
||||
|
||||
__all__ = [
|
||||
"OpParallelConfig"
|
||||
|
@ -56,8 +57,8 @@ class OpParallelConfig(_Config):
|
|||
def __init__(self, data_parallel=1, model_parallel=1):
|
||||
Validator.check_positive_int(data_parallel, "data_parallel")
|
||||
Validator.check_positive_int(model_parallel, "model_parallel")
|
||||
self._data_parallel = data_parallel
|
||||
self._model_parallel = model_parallel
|
||||
self.data_parallel = data_parallel
|
||||
self.model_parallel = model_parallel
|
||||
|
||||
@property
|
||||
def data_parallel(self):
|
||||
|
@ -95,8 +96,8 @@ class _PipeLineConfig(_Config):
|
|||
def __init__(self, pipeline_stage=1, micro_batch_num=1):
|
||||
Validator.check_positive_int(pipeline_stage, "pipeline_stage")
|
||||
Validator.check_positive_int(micro_batch_num, "micro_batch_num")
|
||||
self._pipeline_stage = pipeline_stage
|
||||
self._micro_batch_num = micro_batch_num
|
||||
self.pipeline_stage = pipeline_stage
|
||||
self.micro_batch_num = micro_batch_num
|
||||
|
||||
@property
|
||||
def pipeline_stage(self):
|
||||
|
@ -150,9 +151,10 @@ def _check_config(config):
|
|||
"should be less than device_num {device_num}")
|
||||
|
||||
# the config optimizer_shard is same with context.optimizer_shard
|
||||
if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
|
||||
raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
|
||||
f"optimizer_shard {config.optimizer_shard} in the config")
|
||||
if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
|
||||
logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
|
||||
f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
|
||||
f"optimizer_shard to make them consistent.")
|
||||
|
||||
# pipeline_stage <= micro_batch_num
|
||||
if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
NOTE:
|
||||
Note:
|
||||
Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
|
||||
"""
|
||||
import math
|
||||
|
@ -50,7 +50,7 @@ __all__ = [
|
|||
@constexpr
|
||||
def _check_input_shape(input_shape, param_name, func_name, target_len):
|
||||
if len(input_shape) != target_len:
|
||||
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
|
||||
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
|
||||
return True
|
||||
|
||||
|
||||
|
@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config):
|
|||
def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
|
||||
self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
|
||||
Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
|
||||
self._vocab_emb_dp = vocab_emb_dp
|
||||
self.vocab_emb_dp = vocab_emb_dp
|
||||
|
||||
@property
|
||||
def data_parallel(self):
|
||||
|
@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config):
|
|||
|
||||
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
|
||||
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
|
||||
Validator.check_bool(recompute, "recompute")
|
||||
Validator.check_bool(optimizer_shard, "optimizer_shard")
|
||||
Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group")
|
||||
self.recompute = recompute
|
||||
self.optimizer_shard = optimizer_shard
|
||||
self.gradient_aggregation_group = gradient_aggregation_group
|
||||
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
|
||||
vocab_emb_dp=vocab_emb_dp)
|
||||
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
|
||||
self._recompute = recompute
|
||||
self._optimizer_shard = optimizer_shard
|
||||
self._gradient_aggregation_group = gradient_aggregation_group
|
||||
|
||||
@property
|
||||
def recompute(self):
|
||||
|
@ -256,7 +253,7 @@ class TransformerOpParallelConfig(_Config):
|
|||
def optimizer_shard(self, value):
|
||||
Validator.check_bool(value, "optimizer_shard")
|
||||
self._optimizer_shard = value
|
||||
context.set_auto_parallel_context(optimizer_shard=value)
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=value)
|
||||
|
||||
@property
|
||||
def embedding_dp_mp_config(self):
|
||||
|
@ -322,6 +319,8 @@ class FeedForward(Cell):
|
|||
Raises:
|
||||
ValueError: `hidden_act` is not a string.
|
||||
ValueError: `parallel_config` is not a subclass of OpParallelConfig.
|
||||
ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
|
||||
ValueError: `hidden_size` is not a multiple of the model parallel way.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
@ -349,6 +348,13 @@ class FeedForward(Cell):
|
|||
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
|
||||
dp = parallel_config.data_parallel
|
||||
mp = parallel_config.model_parallel
|
||||
if ffn_hidden_size % mp != 0:
|
||||
raise ValueError("ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
|
||||
if hidden_size % mp != 0:
|
||||
raise ValueError("hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
|
||||
if dropout_rate < 0 or dropout_rate >= 1:
|
||||
raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
|
||||
"but got {}".format(dropout_rate))
|
||||
input_size = hidden_size
|
||||
output_size = ffn_hidden_size
|
||||
# Project to ffn_hidden_size
|
||||
|
@ -428,7 +434,7 @@ class AttentionMask(Cell):
|
|||
raise ValueError(
|
||||
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
|
||||
self.seq_length = seq_length
|
||||
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),))
|
||||
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.BatchMatMul().shard(
|
||||
((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||
|
@ -492,6 +498,7 @@ class VocabEmbedding(Cell):
|
|||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
|
||||
>>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
|
||||
|
@ -612,7 +619,17 @@ class MultiHeadAttention(Cell):
|
|||
_check_config(parallel_config)
|
||||
self.src_seq_length = src_seq_length
|
||||
self.tgt_seq_length = tgt_seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
Validator.check_positive_int(num_heads, "num_heads")
|
||||
if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
|
||||
raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
|
||||
"but got {}".format(hidden_dropout_rate))
|
||||
if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
|
||||
raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
|
||||
"but got {}".format(attention_dropout_rate))
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
|
||||
if num_heads % parallel_config.model_parallel != 0:
|
||||
raise ValueError(f"The number of heads {num_heads} must be a "
|
||||
f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
|
||||
|
@ -719,7 +736,8 @@ class MultiHeadAttention(Cell):
|
|||
output: Tensor, the output logits of this layer
|
||||
layer_present: Tensor, the feature map of current layer
|
||||
"""
|
||||
|
||||
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
|
||||
value_past, batch_valid_length)
|
||||
query_tensor_original_shape = F.shape(query_tensor)
|
||||
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
|
||||
|
||||
|
@ -795,7 +813,7 @@ class MultiHeadAttention(Cell):
|
|||
# multi head attention considering attention mask
|
||||
attention = self._attn(query, key, value, attention_mask)
|
||||
# [bs, seq_length, embedding_size]
|
||||
attention_merge = self.merge_heads(attention)
|
||||
attention_merge = self._merge_heads(attention)
|
||||
# Output
|
||||
output = self.projection(attention_merge)
|
||||
output = self.dropout(output)
|
||||
|
@ -804,10 +822,14 @@ class MultiHeadAttention(Cell):
|
|||
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
|
||||
value_past=None, batch_valid_length=None):
|
||||
r"""Check inputs"""
|
||||
_check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3)
|
||||
_check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3)
|
||||
_check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3)
|
||||
_check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3)
|
||||
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.hidden_size])
|
||||
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
||||
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
|
||||
|
||||
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
|
||||
|
@ -816,23 +838,8 @@ class MultiHeadAttention(Cell):
|
|||
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
|
||||
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
|
||||
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
|
||||
|
||||
def split_heads(self, x, transpose):
|
||||
"""
|
||||
split 3d tensor to 4d and switch certain axes
|
||||
Inputs:
|
||||
x: input tensor
|
||||
transpose: tuple, the transpose sequence
|
||||
Outputs:
|
||||
x_transpose: the 4d output
|
||||
"""
|
||||
x_size = P.Shape()(x)
|
||||
new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
|
||||
x = self.reshape(x, new_x_shape)
|
||||
x_transpose = self.transpose(x, transpose)
|
||||
return x_transpose
|
||||
|
||||
def merge_heads(self, x):
|
||||
return True
|
||||
def _merge_heads(self, x):
|
||||
"""
|
||||
convert a 4d input to a 3d output
|
||||
|
||||
|
@ -1235,8 +1242,8 @@ class TransformerDecoderLayer(Cell):
|
|||
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
batch_size=batch_size,
|
||||
src_seq_length=src_seq_length,
|
||||
tgt_seq_length=tgt_seq_length,
|
||||
src_seq_length=tgt_seq_length,
|
||||
tgt_seq_length=src_seq_length,
|
||||
hidden_dropout_rate=hidden_dropout_rate,
|
||||
attention_dropout_rate=attention_dropout_rate,
|
||||
softmax_comptue_type=softmax_comptue_type,
|
||||
|
@ -1705,9 +1712,9 @@ class Transformer(Cell):
|
|||
r"""
|
||||
Transformer module including encoder and decoder. The difference with the original implements is the module use
|
||||
the residual addition before the layernormalization. And the default hidden act is `gelu`.
|
||||
The detials can be found in `Attention is all you need<https://arxiv.org/pdf/1706.03762v5.pdf>`_.
|
||||
The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
|
||||
|
||||
NOTE:
|
||||
Note:
|
||||
This is an experimental interface that is subject to change and/or deletion.
|
||||
|
||||
Args:
|
||||
|
@ -1832,7 +1839,7 @@ class Transformer(Cell):
|
|||
raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
|
||||
f"layer {decoder_layers}, please use TransformerDecoder")
|
||||
if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
|
||||
raise ValueError("The transformer with encoder and decoder does not support use_past.")
|
||||
raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
|
||||
# The shard setting of Transformer is set within the class StackedTransformer
|
||||
if not lambda_func:
|
||||
lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)
|
||||
|
|
|
@ -203,6 +203,20 @@ def test_multihead_attention():
|
|||
_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
|
||||
|
||||
def test_multihead_attention_wrong_batch():
|
||||
model = MultiHeadAttention(hidden_size=15,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=20,
|
||||
batch_size=2,
|
||||
num_heads=3)
|
||||
from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
|
||||
to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
|
||||
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
|
||||
|
||||
def test_feedforward_layer():
|
||||
model = FeedForward(hidden_size=15,
|
||||
ffn_hidden_size=30,
|
||||
|
|
|
@ -212,6 +212,35 @@ def test_pipeline_single_transformer():
|
|||
model.train(1, dataset, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_transformer_wrong_head():
|
||||
set_auto_parallel_context(device_num=32,
|
||||
full_batch=True,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
|
||||
with pytest.raises(ValueError):
|
||||
net = Transformer(batch_size=4,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
hidden_size=64,
|
||||
num_heads=7,
|
||||
ffn_hidden_size=64,
|
||||
parallel_config=error_test_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
net = Transformer(batch_size=4,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
hidden_size=63,
|
||||
num_heads=7,
|
||||
ffn_hidden_size=64,
|
||||
parallel_config=error_test_config)
|
||||
del net
|
||||
|
||||
def test_encoder():
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
|
Loading…
Reference in New Issue