!22360 Fix Transformer Mirror Error
Merge pull request !22360 from huangxinjing/fix_transformer_mirror_error
commit 8d00a8d803
@@ -21,6 +21,7 @@ from mindspore import context
 import mindspore.communication.management as D
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore import log as logger
 
 __all__ = [
     "OpParallelConfig"
@@ -56,8 +57,8 @@ class OpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1):
         Validator.check_positive_int(data_parallel, "data_parallel")
         Validator.check_positive_int(model_parallel, "model_parallel")
-        self._data_parallel = data_parallel
+        self.data_parallel = data_parallel
-        self._model_parallel = model_parallel
+        self.model_parallel = model_parallel
 
     @property
     def data_parallel(self):
@@ -95,8 +96,8 @@ class _PipeLineConfig(_Config):
     def __init__(self, pipeline_stage=1, micro_batch_num=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
-        self._pipeline_stage = pipeline_stage
+        self.pipeline_stage = pipeline_stage
-        self._micro_batch_num = micro_batch_num
+        self.micro_batch_num = micro_batch_num
 
     @property
     def pipeline_stage(self):
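
The two hunks above replace direct writes to the private attributes with assignments to the public names, so the values set in `__init__` now route through the same property interface that later reassignments of `data_parallel`, `model_parallel`, `pipeline_stage`, or `micro_batch_num` use. A minimal, self-contained sketch of that property pattern, illustrative only and not the MindSpore source:

class _DemoParallelConfig:
    """Illustrative stand-in for the property pattern used by these configs."""

    def __init__(self, data_parallel=1, model_parallel=1):
        # These assignments go through the property setters below,
        # so the check lives in exactly one place.
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"data_parallel must be a positive int, but got {value}")
        self._data_parallel = value

    @property
    def model_parallel(self):
        return self._model_parallel

    @model_parallel.setter
    def model_parallel(self, value):
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"model_parallel must be a positive int, but got {value}")
        self._model_parallel = value


cfg = _DemoParallelConfig(data_parallel=2, model_parallel=4)
cfg.model_parallel = 8  # re-validated on later assignment as well
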
@@ -150,9 +151,10 @@ def _check_config(config):
                          "should be less than device_num {device_num}")
 
     # the config optimizer_shard is same with context.optimizer_shard
-    if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
+    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
-        raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
+        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
-                         f"optimizer_shard {config.optimizer_shard} in the config")
+                       f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
+                       f"optimizer_shard to make them consistent.")
 
     # pipeline_stage <= micro_batch_num
     if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
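
Two behavioral changes land in this hunk: the mismatch check is skipped entirely when optimizer shard is disabled in the auto-parallel context, and a mismatch now logs a warning instead of raising. A rough, self-contained sketch of the new control flow with a hypothetical stand-in config object, not the MindSpore code path:

import logging
from types import SimpleNamespace

logger = logging.getLogger("demo")

def demo_check_optimizer_shard(config, context_optimizer_shard):
    # Only complain when the context actually enables optimizer shard
    # and the config disagrees; otherwise stay silent.
    if hasattr(config, "optimizer_shard") and context_optimizer_shard \
            and context_optimizer_shard != config.optimizer_shard:
        logger.warning("optimizer shard %s in auto_parallel_context is not equal to "
                       "optimizer_shard %s in the OpParallelConfig",
                       context_optimizer_shard, config.optimizer_shard)

cfg = SimpleNamespace(optimizer_shard=False)
demo_check_optimizer_shard(cfg, context_optimizer_shard=False)  # no warning: context flag off
demo_check_optimizer_shard(cfg, context_optimizer_shard=True)   # warning: values disagree
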
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-NOTE:
+Note:
     Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
 """
 import math
@@ -50,7 +50,7 @@ __all__ = [
 @constexpr
 def _check_input_shape(input_shape, param_name, func_name, target_len):
     if len(input_shape) != target_len:
-        raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
+        raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
     return True
 
 
@@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
         self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
         Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
-        self._vocab_emb_dp = vocab_emb_dp
+        self.vocab_emb_dp = vocab_emb_dp
 
     @property
     def data_parallel(self):
@@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config):
 
     def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
-        Validator.check_bool(recompute, "recompute")
+        self.recompute = recompute
-        Validator.check_bool(optimizer_shard, "optimizer_shard")
+        self.optimizer_shard = optimizer_shard
-        Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group")
+        self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
                                                              vocab_emb_dp=vocab_emb_dp)
         self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
-        self._recompute = recompute
-        self._optimizer_shard = optimizer_shard
-        self._gradient_aggregation_group = gradient_aggregation_group
 
     @property
     def recompute(self):
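
Because `recompute`, `optimizer_shard`, and `gradient_aggregation_group` are now assigned through their property setters in `__init__`, constructing a `TransformerOpParallelConfig` validates the arguments and applies their side effects immediately (for example, the `optimizer_shard` setter calls `context.set_auto_parallel_context`, as the next hunk shows). A hedged usage sketch, using only the constructor arguments visible in this diff; the import path is an assumption:

# Import path is an assumption; use the module touched in this diff.
from mindspore.parallel.nn import TransformerOpParallelConfig

# Constructing the config routes every argument through a property setter,
# so validation and context side effects happen at construction time.
parallel_config = TransformerOpParallelConfig(data_parallel=2,
                                              model_parallel=4,
                                              pipeline_stage=1,
                                              micro_batch_num=1,
                                              recompute=True,
                                              optimizer_shard=True,
                                              gradient_aggregation_group=4,
                                              vocab_emb_dp=True)
# The optimizer_shard setter also calls
# context.set_auto_parallel_context(enable_parallel_optimizer=True).
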
@@ -256,7 +253,7 @@ class TransformerOpParallelConfig(_Config):
     def optimizer_shard(self, value):
         Validator.check_bool(value, "optimizer_shard")
         self._optimizer_shard = value
-        context.set_auto_parallel_context(optimizer_shard=value)
+        context.set_auto_parallel_context(enable_parallel_optimizer=value)
 
     @property
     def embedding_dp_mp_config(self):
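
The setter keeps the private `_optimizer_shard` field but now forwards the flag to the auto-parallel context under the `enable_parallel_optimizer` key. A small hedged round-trip check, assuming `context.get_auto_parallel_context` accepts that key:

from mindspore import context

# Mirrors what the setter does with value=True, then reads the flag back.
context.set_auto_parallel_context(enable_parallel_optimizer=True)
assert context.get_auto_parallel_context("enable_parallel_optimizer")
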
@@ -322,6 +319,8 @@ class FeedForward(Cell):
     Raises:
         ValueError: `hidden_act` is not a string.
         ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+        ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
+        ValueError: `hidden_size` is not a multiple of the model parallel way.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -349,6 +348,13 @@ class FeedForward(Cell):
                 f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
+        if ffn_hidden_size % mp != 0:
+            raise ValueError("ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
+        if hidden_size % mp != 0:
+            raise ValueError("hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
+        if dropout_rate < 0 or dropout_rate >= 1:
+            raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(dropout_rate))
         input_size = hidden_size
         output_size = ffn_hidden_size
         # Project to ffn_hidden_size
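
These checks reject configurations whose weights cannot be split evenly across the model-parallel dimension, matching the two new entries in the Raises section above. A hedged example of the failure mode, written in the same style as the unit tests later in this PR; the import paths are an assumption:

import pytest
# Import paths are assumptions; use wherever FeedForward and OpParallelConfig
# live in this source tree.
from mindspore.parallel.nn import FeedForward, OpParallelConfig

bad_config = OpParallelConfig(data_parallel=1, model_parallel=2)

# hidden_size=15 is not divisible by model_parallel=2, so construction is
# expected to raise ValueError up front instead of failing during sharding.
with pytest.raises(ValueError):
    FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
                parallel_config=bad_config)
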
@@ -428,7 +434,7 @@ class AttentionMask(Cell):
             raise ValueError(
                 f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         self.seq_length = seq_length
-        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),))
+        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
         self.reshape = P.Reshape()
         self.mul = P.BatchMatMul().shard(
             ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
@@ -492,6 +498,7 @@ class VocabEmbedding(Cell):
 
     Supported Platforms:
         ``Ascend`` ``GPU``
+
     Examples:
         >>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
         >>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
@@ -612,7 +619,17 @@ class MultiHeadAttention(Cell):
         _check_config(parallel_config)
         self.src_seq_length = src_seq_length
         self.tgt_seq_length = tgt_seq_length
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size
         Validator.check_positive_int(num_heads, "num_heads")
+        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
+            raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(hidden_dropout_rate))
+        if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
+            raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(attention_dropout_rate))
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
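
Caching `hidden_size` and `batch_size` on the cell is what the stricter `_check_inputs` later in this diff relies on, and the added range and divisibility checks now fail at construction time. A hedged example of one new failure mode, mirroring the tests added in this PR; the import path is an assumption:

import pytest
from mindspore.parallel.nn import MultiHeadAttention  # import path is an assumption

# hidden_size=15 is not a multiple of num_heads=4, so the new check in
# __init__ is expected to raise ValueError before any graph is compiled.
with pytest.raises(ValueError):
    MultiHeadAttention(hidden_size=15,
                       src_seq_length=20,
                       tgt_seq_length=20,
                       batch_size=2,
                       num_heads=4)
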
@@ -719,7 +736,8 @@ class MultiHeadAttention(Cell):
             output: Tensor, the output logits of this layer
             layer_present: Tensor, the feature map of current layer
         """
+        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
+                           value_past, batch_valid_length)
         query_tensor_original_shape = F.shape(query_tensor)
         query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
 
@@ -795,7 +813,7 @@ class MultiHeadAttention(Cell):
         # multi head attention considering attention mask
         attention = self._attn(query, key, value, attention_mask)
         # [bs, seq_length, embedding_size]
-        attention_merge = self.merge_heads(attention)
+        attention_merge = self._merge_heads(attention)
         # Output
         output = self.projection(attention_merge)
         output = self.dropout(output)
@@ -804,10 +822,14 @@ class MultiHeadAttention(Cell):
     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                       value_past=None, batch_valid_length=None):
         r"""Check inputs"""
-        _check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3)
+        _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.tgt_seq_length])
 
         _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
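
`_check_shape_equal` compares the whole shape against the expected `[batch, seq_length, hidden_size]` layout instead of only checking the rank, which is what lets the new `test_multihead_attention_wrong_batch` case below fail cleanly. Its definition is not part of this diff; the following is only a hedged sketch of what such a helper could look like, not the MindSpore implementation (the real one is presumably a `@constexpr` function so it runs at graph-compile time):

def demo_check_shape_equal(input_shape, param_name, func_name, target_shape):
    """Illustrative only: require input_shape to match target_shape exactly."""
    if len(input_shape) != len(target_shape):
        raise ValueError(f"{func_name} {param_name} should be {len(target_shape)}d, "
                         f"but got shape {input_shape}")
    if list(input_shape) != list(target_shape):
        raise ValueError(f"{func_name} {param_name} should have shape {target_shape}, "
                         f"but got {input_shape}")
    return True
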
@@ -816,23 +838,8 @@ class MultiHeadAttention(Cell):
         _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
         _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
         _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
+        return True
-
-    def split_heads(self, x, transpose):
-        """
-        split 3d tensor to 4d and switch certain axes
-        Inputs:
-            x: input tensor
-            transpose: tuple, the transpose sequence
-        Outputs:
-            x_transpose: the 4d output
-        """
-        x_size = P.Shape()(x)
-        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
-        x = self.reshape(x, new_x_shape)
-        x_transpose = self.transpose(x, transpose)
-        return x_transpose
-
-    def merge_heads(self, x):
+    def _merge_heads(self, x):
         """
         convert a 4d input to a 3d output
 
@@ -1235,8 +1242,8 @@ class TransformerDecoderLayer(Cell):
         self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
                                                   num_heads=num_heads,
                                                   batch_size=batch_size,
-                                                  src_seq_length=src_seq_length,
+                                                  src_seq_length=tgt_seq_length,
-                                                  tgt_seq_length=tgt_seq_length,
+                                                  tgt_seq_length=src_seq_length,
                                                   hidden_dropout_rate=hidden_dropout_rate,
                                                   attention_dropout_rate=attention_dropout_rate,
                                                   softmax_comptue_type=softmax_comptue_type,
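
In cross attention the queries come from the decoder side (length `tgt_seq_length` of the decoder layer), while keys and values come from the encoder output (length `src_seq_length`). `MultiHeadAttention` names its query length `src_seq_length` and its key/value length `tgt_seq_length` (see the `_check_inputs` shapes above), so the two arguments have to be swapped here. A short shape sketch of the resulting attention score matrix, using plain NumPy and illustrative sizes:

import numpy as np

batch, num_heads, head_dim = 2, 4, 8
decoder_len, encoder_len = 10, 20   # tgt_seq_length, src_seq_length of the layer

query = np.ones((batch, num_heads, decoder_len, head_dim))   # from decoder states
key = np.ones((batch, num_heads, encoder_len, head_dim))     # from encoder output

# score[b, h] has shape (decoder_len, encoder_len): each decoder position
# attends over every encoder position.
score = np.matmul(query, key.transpose(0, 1, 3, 2))
assert score.shape == (batch, num_heads, decoder_len, encoder_len)
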
@@ -1705,9 +1712,9 @@ class Transformer(Cell):
     r"""
     Transformer module including encoder and decoder. The difference with the original implements is the module use
     the residual addition before the layernormalization. And the default hidden act is `gelu`.
-    The detials can be found in `Attention is all you need<https://arxiv.org/pdf/1706.03762v5.pdf>`_.
+    The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
 
-    NOTE:
+    Note:
         This is an experimental interface that is subject to change and/or deletion.
 
     Args:
@@ -1832,7 +1839,7 @@ class Transformer(Cell):
             raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
                              f"layer {decoder_layers}, please use TransformerDecoder")
         if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
-            raise ValueError("The transformer with encoder and decoder does not support use_past.")
+            raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
         # The shard setting of Transformer is set within the class StackedTransformer
         if not lambda_func:
             lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)
@@ -203,6 +203,20 @@ def test_multihead_attention():
     _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
 
 
+def test_multihead_attention_wrong_batch():
+    model = MultiHeadAttention(hidden_size=15,
+                               src_seq_length=20,
+                               tgt_seq_length=20,
+                               batch_size=2,
+                               num_heads=3)
+    from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
+
+    with pytest.raises(ValueError):
+        _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
 def test_feedforward_layer():
     model = FeedForward(hidden_size=15,
                         ffn_hidden_size=30,
@@ -212,6 +212,35 @@ def test_pipeline_single_transformer():
     model.train(1, dataset, dataset_sink_mode=False)
 
 
+def test_transformer_wrong_head():
+    set_auto_parallel_context(device_num=32,
+                              full_batch=True,
+                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
+                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+    error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=64,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=63,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+    del net
+
 def test_encoder():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
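
Both constructions in this new test are expected to fail against `model_parallel=8`: in the first, `num_heads=7` is not a multiple of the model-parallel dimension; in the second, `hidden_size=63` is divisible by `num_heads=7` but not by `model_parallel=8`, which trips the new divisibility checks added earlier in this PR. The arithmetic, spelled out:

model_parallel, num_heads, hidden_size = 8, 7, 63
assert num_heads % model_parallel != 0     # first case fails on the heads check
assert hidden_size % num_heads == 0        # 63 / 7 = 9, so the heads check passes
assert hidden_size % model_parallel != 0   # second case fails on the sharding check
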