From 6cea07f749f6496cc428aea0753dbfa23e77e9ed Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Tue, 7 Sep 2021 14:17:33 +0800
Subject: [PATCH] Add args check

---
 mindspore/parallel/nn/layers.py            |  67 +++++-
 mindspore/parallel/nn/transformer.py       | 193 +++++++++++++-----
 .../nlp/pangu_alpha/src/pangu_alpha.py     |   2 +-
 tests/ut/python/nn/test_transformer.py     |   8 +-
 .../parallel/test_parallel_transformer.py  |  58 ++++++
 5 files changed, 261 insertions(+), 67 deletions(-)

diff --git a/mindspore/parallel/nn/layers.py b/mindspore/parallel/nn/layers.py
index d4a081a0242..764ba491cb4 100644
--- a/mindspore/parallel/nn/layers.py
+++ b/mindspore/parallel/nn/layers.py
@@ -16,7 +16,8 @@
 The basic layer of the Transformer Networks. This is an experimental interface that is subject to change
 and/or deletion.
 """
-
+from functools import wraps, partial
+import inspect
 import math
 import numpy as np
 from mindspore.common.parameter import Parameter
@@ -37,6 +38,53 @@ __all__ = [
 ]
 
 
+def _args_type_validator_check(*type_args, **type_kwargs):
+    """Check whether input data type is correct."""
+
+    def type_check(func):
+        sig = inspect.signature(func)
+        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            nonlocal bound_types
+            bound_values = sig.bind(*args, **kwargs)
+
+            argument_dict = bound_values.arguments
+            if "kwargs" in bound_types:
+                bound_types = bound_types["kwargs"]
+            if "kwargs" in argument_dict:
+                argument_dict = argument_dict["kwargs"]
+            for name, value in argument_dict.items():
+                if name in bound_types:
+                    bound_types[name](value, name)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return type_check
+
+
+def _valid_type_checks(types, class_name):
+    # types should be a list of valid types; this checker verifies that type(value) is one of them
+    def validator_check_func(value, name):
+        # The args of Validator.check_type_name are (arg_name, arg_type, valid_types, prim_name).
+        # The checker signature used by _args_type_validator_check is fixed, so reorder the inputs here.
+        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
+        return partial_check(name, type(value))
+    return validator_check_func
+
+
+def _valid_value_checks(types, class_name):
+    # types should be a list of valid values (dtypes); this checker verifies that the value is one of them
+    def validator_check_func(value, name):
+        # The args of Validator.check_type_name are (arg_name, arg_type, valid_types, prim_name).
+        # The checker signature used by _args_type_validator_check is fixed, so reorder the inputs here.
+        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
+        return partial_check(name, value)
+    return validator_check_func
+
+
 @constexpr
 def _check_input_shape(input_shape, param_name, func_name, target_len):
     if len(input_shape) != target_len:
@@ -339,7 +387,13 @@ class FixedSparseAttention(nn.Cell):
         >>> print(output.shape)
         (2, 1024, 512)
     """
-
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                size_per_head=Validator.check_positive_int,
+                                block_size=Validator.check_positive_int,
+                                seq_length=Validator.check_positive_int,
+                                num_different_global_patterns=Validator.check_positive_int,
+                                parallel_config=_valid_type_checks([OpParallelConfig], "FixedSparseAttention"))
     def __init__(self,
                  batch_size,
                  num_heads,
                  size_per_head,
                  block_size=64,
                  seq_length=1024,
                  num_different_global_patterns=4,
                  parallel_config=default_dpmp_config):
         super(FixedSparseAttention, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise TypeError(
-                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
-        Validator.check_positive_int(batch_size, "batch_size")
-        Validator.check_positive_int(num_heads, "num_heads")
-        Validator.check_positive_int(size_per_head, "size_per_head")
-        Validator.check_positive_int(block_size, "block_size")
-        Validator.check_positive_int(seq_length, "seq_length")
-        Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
         dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
         if num_heads % mp != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
diff --git a/mindspore/parallel/nn/transformer.py b/mindspore/parallel/nn/transformer.py
index f81e0cf5c8b..10786ed826a 100644
--- a/mindspore/parallel/nn/transformer.py
+++ b/mindspore/parallel/nn/transformer.py
@@ -29,8 +29,9 @@ from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator
 from mindspore import log as logger
-from .layers import _LayerNorm, _Linear, _check_input_shape,\
-    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
+from .layers import _LayerNorm, _Linear, _check_input_shape, \
+    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
+    _args_type_validator_check, _valid_type_checks, _valid_value_checks
 from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config
 
 __all__ = [
@@ -292,6 +293,14 @@ class FeedForward(Cell):
         (2, 20, 15)
     """
+    @_args_type_validator_check(hidden_size=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "FeedForward"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "FeedForward"),
+                                parallel_config=_valid_type_checks([OpParallelConfig],
+                                                                   "FeedForward"))
     def __init__(self, hidden_size,
                  ffn_hidden_size,
                  dropout_rate,
@@ -300,13 +309,6 @@ class FeedForward(Cell):
                  parallel_config=default_dpmp_config):
         super(FeedForward, self).__init__()
         _check_config(parallel_config)
-        Validator.check_positive_int(hidden_size, "hidden_size")
-        Validator.check_positive_int(ffn_hidden_size, "ffn_hidden_size")
-        if not isinstance(hidden_act, str):
-            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise ValueError(
-                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
         if ffn_hidden_size % mp != 0:
@@ -388,12 +390,10 @@ class AttentionMask(Cell):
              [0, 0, 0, 0]]])
     """
+    @_args_type_validator_check(seq_length=Validator.check_positive_int,
+                                parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask"))
     def __init__(self, seq_length, parallel_config=default_dpmp_config):
         super(AttentionMask, self).__init__()
-        Validator.check_positive_int(seq_length, "seq_length")
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise ValueError(
-                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         self.seq_length = seq_length
         self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
         self.reshape = P.Reshape()
@@ -470,15 +470,13 @@ class VocabEmbedding(Cell):
         (30, 30)
     """
+    @_args_type_validator_check(vocab_size=Validator.check_positive_int,
+                                embedding_size=Validator.check_positive_int,
+                                parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding"))
     def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config,
                  param_init='normal'):
         super(VocabEmbedding, self).__init__()
         _check_config(parallel_config)
-        Validator.check_positive_int(vocab_size, "vocab_size")
-        Validator.check_positive_int(embedding_size, "embedding_size")
-        if not isinstance(parallel_config, EmbeddingOpParallelConfig):
-            raise ValueError(f"The parallel_config should be a VocabEmbedding type, but found {type(parallel_config)}")
-
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
         self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
@@ -564,6 +562,20 @@ class MultiHeadAttention(Cell):
         (2, 3, 20, 5)
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                src_seq_length=Validator.check_positive_int,
+                                tgt_seq_length=Validator.check_positive_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "MultiHeadAttention"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "MultiHeadAttention"),
+                                parallel_config=_valid_type_checks([OpParallelConfig],
+                                                                   "MultiHeadAttention"),
+                                use_past=Validator.check_bool)
     def __init__(self,
                  batch_size,
                  src_seq_length,
                  tgt_seq_length,
@@ -582,11 +594,6 @@ class MultiHeadAttention(Cell):
         self.tgt_seq_length = tgt_seq_length
         self.hidden_size = hidden_size
         self.batch_size = batch_size
-        Validator.check_positive_int(num_heads, "num_heads")
-        Validator.check_positive_int(batch_size, "batch_size")
-        Validator.check_positive_int(src_seq_length, "src_seq_length")
-        Validator.check_positive_int(tgt_seq_length, "tgt_seq_length")
-        Validator.check_positive_int(num_heads, "num_heads")
         if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
             raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
                              "but got {}".format(hidden_dropout_rate))
@@ -787,11 +794,13 @@ class MultiHeadAttention(Cell):
         _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
+        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
         _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
         _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
         _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
         return True
 
+
     def _merge_heads(self, x):
         """
             convert a 4d input to a 3d output
@@ -932,6 +941,24 @@ class TransformerEncoderLayer(Cell):
         (2, 2, 16, 4)
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                seq_length=Validator.check_positive_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
+                                post_layernorm_residual=Validator.check_bool,
+                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                           "TransformerEncoderLayer"),
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "TransformerEncoderLayer"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "TransformerEncoderLayer"),
+                                parallel_config=_valid_type_checks([OpParallelConfig],
+                                                                   "TransformerEncoderLayer"),
+                                use_past=Validator.check_bool)
     def __init__(self,
                  batch_size,
                  hidden_size,
@@ -953,9 +980,6 @@ class TransformerEncoderLayer(Cell):
             raise ValueError(
                 f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
                 f"but found {num_heads}")
-        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
-        if not isinstance(hidden_act, str):
-            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
         self.use_past = use_past
         self.seq_length = seq_length
         self.hidden_size = hidden_size
@@ -986,6 +1010,8 @@ class TransformerEncoderLayer(Cell):
         self.post_layernorm_residual = post_layernorm_residual
         self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
         self.dtype = mstype.float16
+        self.key_past = None
+        self.value_past = None
 
         if self.use_past:
             # operator used for state reuse
@@ -1058,8 +1084,10 @@ class TransformerEncoderLayer(Cell):
         r"""Check inputs"""
         _check_shape_equal(F.shape(x), "x", self.cls_name,
                            [self.batch_size, self.seq_length, self.hidden_size])
+        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
         _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
                            [self.batch_size, self.seq_length, self.seq_length])
+        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
         _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
         if init_reset is not True:
             _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
@@ -1143,6 +1171,25 @@ class TransformerDecoderLayer(Cell):
         (2, 2, 20, 32)
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                src_seq_length=Validator.check_positive_int,
+                                tgt_seq_length=Validator.check_positive_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
+                                post_layernorm_residual=Validator.check_bool,
+                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                           "TransformerDecoderLayer"),
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "TransformerDecoderLayer"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "TransformerDecoderLayer"),
+                                parallel_config=_valid_type_checks([OpParallelConfig],
+                                                                   "TransformerDecoderLayer"),
+                                use_past=Validator.check_bool)
     def __init__(self, hidden_size,
                  ffn_hidden_size,
                  num_heads,
@@ -1163,13 +1210,7 @@ class TransformerDecoderLayer(Cell):
         self.batch_size = batch_size
         self.use_past = use_past
         self.softmax_comptue_type = softmax_comptue_type
-        if num_heads % parallel_config.model_parallel != 0:
-            raise ValueError(
-                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
-                f"but found {num_heads}")
-        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
-        if not isinstance(hidden_act, str):
-            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
+
         self.src_seq_length = src_seq_length
         self.tgt_seq_length = tgt_seq_length
         self.use_past = use_past
@@ -1217,6 +1258,8 @@ class TransformerDecoderLayer(Cell):
         self.post_layernorm_residual = post_layernorm_residual
         self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
         self.dtype = mstype.float16
+        self.key_past = None
+        self.value_past = None
         if self.use_past:
             # operator used for state reuse
             self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
@@ -1308,12 +1351,18 @@ class TransformerDecoderLayer(Cell):
                            [self.batch_size, self.tgt_seq_length, self.hidden_size])
         _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                            [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
+        _check_input_dtype(F.dtype(hidden_states), "hidden_size", [mstype.float32, mstype.float16], self.cls_name)
+        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
         if encoder_output is not None:
             _check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
                                [self.batch_size, self.src_seq_length, self.hidden_size])
+            _check_input_dtype(F.dtype(encoder_output), "encoder_output",
+                               [mstype.float32, mstype.float16], self.cls_name)
         if memory_mask is not None:
             _check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
                                [self.batch_size, self.tgt_seq_length, self.src_seq_length])
+            _check_input_dtype(F.dtype(memory_mask), "memory_mask",
+                               [mstype.float32, mstype.float16], self.cls_name)
 
         _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
         if init_reset is not True:
@@ -1437,6 +1486,26 @@ class TransformerEncoder(Cell):
         (2, 2, 16, 4)
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                seq_length=Validator.check_positive_int,
+                                num_layers=Validator.check_positive_int,
+                                offset=Validator.check_non_negative_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "TransformerEncoder"),
+                                post_layernorm_residual=Validator.check_bool,
+                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                           "TransformerEncoder"),
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "TransformerEncoder"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "TransformerEncoder"),
+                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
+                                                                   "TransformerEncoder"),
+                                use_past=Validator.check_bool)
     def __init__(self,
                  batch_size,
                  num_layers,
@@ -1457,15 +1526,6 @@
                  parallel_config=default_transformer_config):
         super(TransformerEncoder, self).__init__()
         _check_config(parallel_config)
-        Validator.check_positive_int(num_layers, "num_layers")
-        Validator.check_non_negative_int(offset, "offset")
-        if num_heads % parallel_config.model_parallel != 0:
-            raise ValueError(
-                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
-                f"but found {num_heads}")
-        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
-        if not isinstance(hidden_act, str):
-            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
         self.num_layers = num_layers
         self.blocks = nn.CellList()
@@ -1587,6 +1647,27 @@ class TransformerDecoder(Cell):
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                src_seq_length=Validator.check_positive_int,
+                                num_layers=Validator.check_positive_int,
+                                tgt_seq_length=Validator.check_positive_int,
+                                offset=Validator.check_non_negative_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "TransformerDecoder"),
+                                post_layernorm_residual=Validator.check_bool,
+                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                           "TransformerDecoder"),
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "TransformerDecoder"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                    "TransformerDecoder"),
+                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
+                                                                   "TransformerDecoder"),
+                                use_past=Validator.check_bool)
     def __init__(self,
                  num_layers,
                  batch_size,
@@ -1608,15 +1689,6 @@
                  parallel_config=default_transformer_config):
         super(TransformerDecoder, self).__init__()
         _check_config(parallel_config)
-        Validator.check_positive_int(num_layers, "num_layers")
-        Validator.check_non_negative_int(offset, "offset")
-        if num_heads % parallel_config.model_parallel != 0:
-            raise ValueError(
-                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
-                f"but found {num_heads}")
-        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
-        if not isinstance(hidden_act, str):
-            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
         self.num_layers = num_layers
         self.blocks = nn.CellList()
@@ -1762,6 +1834,25 @@ class Transformer(Cell):
     """
+    @_args_type_validator_check(batch_size=Validator.check_positive_int,
+                                hidden_size=Validator.check_positive_int,
+                                num_heads=Validator.check_positive_int,
+                                ffn_hidden_size=Validator.check_positive_int,
+                                src_seq_length=Validator.check_positive_int,
+                                encoder_layers=Validator.check_positive_int,
+                                decoder_layers=Validator.check_non_negative_int,
+                                tgt_seq_length=Validator.check_positive_int,
+                                attention_dropout_rate=Validator.check_non_negative_float,
+                                hidden_dropout_rate=Validator.check_non_negative_float,
+                                hidden_act=_valid_type_checks([str], "Transformer"),
+                                post_layernorm_residual=Validator.check_bool,
+                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                           "Transformer"),
+                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+                                                                         "Transformer"),
+                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"),
+                                parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"),
+                                use_past=Validator.check_bool)
     def __init__(self,
                  hidden_size,
                  batch_size,
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
index 2b461e9495f..d2e943d18fb 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -207,7 +207,7 @@ def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config,
         print(f"pipeline stage id is {pp_id}", flush=True)
 
     # Used for optimizer's fusion tag
-    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
+    dis = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
     if parallel_config.pipeline_stage > 1:
         # we give the fusion in pipeline mode a fixed value, otherwise the performance may become worse.
         network.set_comm_fusion(2)
diff --git a/tests/ut/python/nn/test_transformer.py b/tests/ut/python/nn/test_transformer.py
index 7eb30315108..f8dc5a53a0a 100644
--- a/tests/ut/python/nn/test_transformer.py
+++ b/tests/ut/python/nn/test_transformer.py
@@ -25,7 +25,7 @@ from mindspore.common.api import _cell_graph_executor
 def test_transformer_encoder_only():
     model = Transformer(batch_size=2,
                         src_seq_length=20,
-                        tgt_seq_length=0,
+                        tgt_seq_length=10,
                         encoder_layers=2,
                         decoder_layers=0,
                         hidden_size=64,
@@ -41,7 +41,7 @@ def test_transformer_encoder_log_softmax():
     with pytest.raises(ValueError):
         model = Transformer(batch_size=2,
                             src_seq_length=20,
-                            tgt_seq_length=0,
+                            tgt_seq_length=10,
                             encoder_layers=2,
                             decoder_layers=0,
                             hidden_act='logsoftmax',
@@ -57,7 +57,7 @@ def test_transformer_encoder_leakyrelu():
     model = Transformer(batch_size=2,
                         src_seq_length=20,
-                        tgt_seq_length=0,
+                        tgt_seq_length=10,
                         encoder_layers=2,
                         decoder_layers=0,
                         hidden_act='leakyrelu',
@@ -73,7 +73,7 @@ def test_transformer_encoder_logsigmoid():
     model = Transformer(batch_size=2,
                         src_seq_length=20,
-                        tgt_seq_length=0,
+                        tgt_seq_length=10,
                         encoder_layers=2,
                         decoder_layers=0,
                         hidden_act='logsigmoid',
diff --git a/tests/ut/python/parallel/test_parallel_transformer.py b/tests/ut/python/parallel/test_parallel_transformer.py
index 73b33ddf645..c3c250c1ea4 100644
--- a/tests/ut/python/parallel/test_parallel_transformer.py
+++ b/tests/ut/python/parallel/test_parallel_transformer.py
@@ -156,6 +156,37 @@ def test_transformer_model():
     model.train(1, dataset, dataset_sink_mode=False)
 
 
+def test_transformer_model_int64_inputs():
+    set_auto_parallel_context(device_num=8, global_rank=0,
+                              full_batch=True,
+                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+    net = Transformer(encoder_layers=1,
+                      decoder_layers=2,
+                      batch_size=2,
+                      src_seq_length=20,
+                      tgt_seq_length=10,
+                      hidden_size=64,
+                      num_heads=8,
+                      ffn_hidden_size=64,
+                      parallel_config=config)
+
+    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
+    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
+    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
+    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
+    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
+    net = NetWithLossFiveInputs(net)
+    params = net.trainable_params()
+    optimizer = AdamWeightDecay(params)
+    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
+                      memory_mask)
+    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
+    model = Model(net_with_grad)
+
+    with pytest.raises(TypeError):
+        model.train(1, dataset, dataset_sink_mode=False)
+
+
 def test_transformer_model_head_parallel_only_encoder():
     local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
     run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
@@ -453,6 +484,33 @@ def test_parallel_cross_entroy_loss_semi_auto_parallel():
     model.train(1, dataset, dataset_sink_mode=False)
 
 
+def test_transformer_args():
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length=20, decoder_layers="aa")
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length="a")
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length=20, softmax_comptue_type=mstype.int64)
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length=20, layernorm_compute_type=mstype.int64)
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length=20, param_init_type=mstype.int64)
+
+    with pytest.raises(TypeError):
+        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+                    tgt_seq_length=20, hidden_dropout_rate=mstype.int64)
+
+
 def test_transformer_parallel_config():
     parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)
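
Note (not part of the patch): the sketch below is a minimal, self-contained version of the decorator-based argument checking that the patch introduces. check_positive_int here is a plain stand-in for Validator.check_positive_int, and TinyLayer is a hypothetical class used only for illustration, so the snippet runs without MindSpore installed.

from functools import wraps
import inspect


def _args_type_validator_check(**type_kwargs):
    """Bind per-argument checker callables to a function's signature (trimmed-down sketch)."""
    def type_check(func):
        sig = inspect.signature(func)
        bound_types = sig.bind_partial(**type_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_values = sig.bind(*args, **kwargs)
            for name, value in bound_values.arguments.items():
                if name in bound_types:
                    bound_types[name](value, name)  # checker raises on invalid input
            return func(*args, **kwargs)
        return wrapper
    return type_check


def check_positive_int(value, name):
    # Stand-in for Validator.check_positive_int: reject non-int or non-positive values.
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise TypeError(f"'{name}' must be a positive int, but got {value!r}")
    return value


class TinyLayer:
    @_args_type_validator_check(hidden_size=check_positive_int,
                                num_heads=check_positive_int)
    def __init__(self, hidden_size, num_heads=8):
        self.hidden_size = hidden_size
        self.num_heads = num_heads


TinyLayer(hidden_size=64)          # passes: the argument is a positive int
try:
    TinyLayer(hidden_size="64")    # fails: the bound checker raises TypeError
except TypeError as err:
    print(err)

As in the patched classes, the checks run once at construction time: an invalid keyword argument raises TypeError before __init__ executes, which is the behaviour the new test_transformer_args cases assert for Transformer.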