forked from mindspore-Ecosystem/mindspore

!23004 Add args Check for Transformer

Merge pull request !23004 from huangxinjing/args_check

This commit is contained in commit 77424eaad5
@@ -16,7 +16,8 @@
The basic layer of the Transformer Networks. This is an experimental interface that is subject to
change and/or deletion.
"""

from functools import wraps, partial
import inspect
import math
import numpy as np
from mindspore.common.parameter import Parameter

@@ -37,6 +38,53 @@ __all__ = [
]


def _args_type_validator_check(*type_args, **type_kwargs):
    """Check whether input data type is correct."""

    def type_check(func):
        sig = inspect.signature(func)
        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal bound_types
            bound_values = sig.bind(*args, **kwargs)

            argument_dict = bound_values.arguments
            if "kwargs" in bound_types:
                bound_types = bound_types["kwargs"]
            if "kwargs" in argument_dict:
                argument_dict = argument_dict["kwargs"]
            for name, value in argument_dict.items():
                if name in bound_types:
                    bound_types[name](value, name)
            return func(*args, **kwargs)

        return wrapper

    return type_check


def _valid_type_checks(types, class_name):
    # types should be a list of types; this function checks whether the type of the value is in the valid types
    def validator_check_func(value, name):
        # The args of Validator.check_type_name are (arg_name, arg_type, valid_types, prim_name).
        # The checker interface of _args_type_validator_check is fixed as (value, name), so the argument order is changed manually here.
        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
        return partial_check(name, type(value))
    return validator_check_func


def _valid_value_checks(types, class_name):
    # types should be a list of valid dtypes; this function checks whether the value itself is in the valid dtypes
    def validator_check_func(value, name):
        # The args of Validator.check_type_name are (arg_name, arg_type, valid_types, prim_name).
        # The checker interface of _args_type_validator_check is fixed as (value, name), so the argument order is changed manually here.
        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
        return partial_check(name, value)
    return validator_check_func


@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
    if len(input_shape) != target_len:
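Note: the decorator above binds a checker callable to each parameter name and runs it on every call, before the wrapped function body executes. Below is a minimal standalone sketch of the same pattern, using a hypothetical check_positive_int helper instead of MindSpore's Validator so it runs without the framework:

import inspect
from functools import wraps


def check_positive_int(value, name):
    """Hypothetical checker mirroring the spirit of Validator.check_positive_int."""
    if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
        raise TypeError(f"'{name}' must be a positive int, but got {value!r}")
    return value


def args_type_check(*type_args, **type_kwargs):
    """Bind checkers to parameter names; call each checker as checker(value, name)."""
    def type_check(func):
        sig = inspect.signature(func)
        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_values = sig.bind(*args, **kwargs)
            for name, value in bound_values.arguments.items():
                if name in bound_types:
                    bound_types[name](value, name)
            return func(*args, **kwargs)
        return wrapper
    return type_check


@args_type_check(batch_size=check_positive_int)
def build(batch_size, seq_length=16):
    return batch_size * seq_length


print(build(2))            # 32
try:
    build(batch_size=-1)   # rejected before the function body runs
except TypeError as err:
    print(err)             # 'batch_size' must be a positive int, but got -1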
@@ -339,7 +387,13 @@ class FixedSparseAttention(nn.Cell):
        >>> print(output.shape)
        (2, 1024, 512)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                size_per_head=Validator.check_positive_int,
                                block_size=Validator.check_positive_int,
                                seq_length=Validator.check_positive_int,
                                num_different_global_patterns=Validator.check_positive_int,
                                parallel_config=_valid_type_checks([OpParallelConfig], "FixedSparseAttention"))
    def __init__(self,
                 batch_size,
                 num_heads,

@@ -349,15 +403,6 @@ class FixedSparseAttention(nn.Cell):
                 num_different_global_patterns=4,
                 parallel_config=default_dpmp_config):
        super(FixedSparseAttention, self).__init__()
        if not isinstance(parallel_config, OpParallelConfig):
            raise TypeError(
                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
        Validator.check_positive_int(batch_size, "batch_size")
        Validator.check_positive_int(num_heads, "num_heads")
        Validator.check_positive_int(size_per_head, "size_per_head")
        Validator.check_positive_int(block_size, "block_size")
        Validator.check_positive_int(seq_length, "seq_length")
        Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
        if num_heads % mp != 0:
            raise ValueError(f"The number of heads {num_heads} must be a "


@@ -29,8 +29,9 @@ from mindspore.ops import functional as F
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator
from mindspore import log as logger
from .layers import _LayerNorm, _Linear, _check_input_shape,\
    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
from .layers import _LayerNorm, _Linear, _check_input_shape, \
    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
    _args_type_validator_check, _valid_type_checks, _valid_value_checks
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config

__all__ = [

@@ -292,6 +293,14 @@ class FeedForward(Cell):
        (2, 20, 15)
    """

    @_args_type_validator_check(hidden_size=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "FeedForward"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "FeedForward"),
                                parallel_config=_valid_type_checks([OpParallelConfig],
                                                                   "FeedForward"))
    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
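In the decorator above, note the split between the two new helpers: hidden_act goes through _valid_type_checks (the Python type of the value must be str), while param_init_type goes through _valid_value_checks (the value itself, a dtype, must be one of the listed dtypes). A framework-free sketch of that distinction, with hypothetical helper names and plain strings standing in for dtypes:

def valid_type_check(types, owner):
    """Accept a value whose Python type is listed (cf. _valid_type_checks)."""
    def check(value, name):
        if type(value) not in types:
            raise TypeError(f"For '{owner}', '{name}' must be of type {types}, "
                            f"but got {type(value)}")
        return value
    return check


def valid_value_check(values, owner):
    """Accept a value that is itself listed (cf. _valid_value_checks)."""
    def check(value, name):
        if value not in values:
            raise TypeError(f"For '{owner}', '{name}' must be one of {values}, "
                            f"but got {value!r}")
        return value
    return check


check_act = valid_type_check([str], "FeedForward")
check_init = valid_value_check(["float16", "float32"], "FeedForward")

check_act("gelu", "hidden_act")            # passes: type of the value is str
check_init("float16", "param_init_type")   # passes: the value is in the allowed list
try:
    check_init("int64", "param_init_type")
except TypeError as err:
    print(err)                             # rejected: value not in the allowed list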
@@ -300,13 +309,6 @@ class FeedForward(Cell):
                 parallel_config=default_dpmp_config):
        super(FeedForward, self).__init__()
        _check_config(parallel_config)
        Validator.check_positive_int(hidden_size, "hidden_size")
        Validator.check_positive_int(ffn_hidden_size, "ffn_hidden_size")
        if not isinstance(hidden_act, str):
            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
        if not isinstance(parallel_config, OpParallelConfig):
            raise ValueError(
                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        if ffn_hidden_size % mp != 0:

@@ -388,12 +390,10 @@ class AttentionMask(Cell):
         [0, 0, 0, 0]]])
    """

    @_args_type_validator_check(seq_length=Validator.check_positive_int,
                                parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask"))
    def __init__(self, seq_length, parallel_config=default_dpmp_config):
        super(AttentionMask, self).__init__()
        Validator.check_positive_int(seq_length, "seq_length")
        if not isinstance(parallel_config, OpParallelConfig):
            raise ValueError(
                f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
        self.seq_length = seq_length
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
        self.reshape = P.Reshape()

@@ -470,15 +470,13 @@ class VocabEmbedding(Cell):
        (30, 30)
    """

    @_args_type_validator_check(vocab_size=Validator.check_positive_int,
                                embedding_size=Validator.check_positive_int,
                                parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding"))
    def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config,
                 param_init='normal'):
        super(VocabEmbedding, self).__init__()
        _check_config(parallel_config)
        Validator.check_positive_int(vocab_size, "vocab_size")
        Validator.check_positive_int(embedding_size, "embedding_size")
        if not isinstance(parallel_config, EmbeddingOpParallelConfig):
            raise ValueError(f"The parallel_config should be a VocabEmbedding type, but found {type(parallel_config)}")

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
@@ -564,6 +562,20 @@ class MultiHeadAttention(Cell):
        (2, 3, 20, 5)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                tgt_seq_length=Validator.check_positive_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "MultiHeadAttention"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "MultiHeadAttention"),
                                parallel_config=_valid_type_checks([OpParallelConfig],
                                                                   "MultiHeadAttention"),
                                use_past=Validator.check_bool)
    def __init__(self, batch_size,
                 src_seq_length,
                 tgt_seq_length,

@@ -582,11 +594,6 @@ class MultiHeadAttention(Cell):
        self.tgt_seq_length = tgt_seq_length
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        Validator.check_positive_int(num_heads, "num_heads")
        Validator.check_positive_int(batch_size, "batch_size")
        Validator.check_positive_int(src_seq_length, "src_seq_length")
        Validator.check_positive_int(tgt_seq_length, "tgt_seq_length")
        Validator.check_positive_int(num_heads, "num_heads")
        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
            raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
                             "but got {}".format(hidden_dropout_rate))

@@ -787,11 +794,13 @@ class MultiHeadAttention(Cell):
        _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

        _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
        _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
        return True

    def _merge_heads(self, x):
        """
        convert a 4d input to a 3d output
@@ -932,6 +941,24 @@ class TransformerEncoderLayer(Cell):
        (2, 2, 16, 4)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                seq_length=Validator.check_positive_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "TransformerEncoderLayer"),
                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "TransformerEncoderLayer"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "TransformerEncoderLayer"),
                                parallel_config=_valid_type_checks([OpParallelConfig],
                                                                   "TransformerEncoderLayer"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 batch_size,
                 hidden_size,

@@ -953,9 +980,6 @@ class TransformerEncoderLayer(Cell):
            raise ValueError(
                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
                f"but found {num_heads}")
        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
        if not isinstance(hidden_act, str):
            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")
        self.use_past = use_past
        self.seq_length = seq_length
        self.hidden_size = hidden_size

@@ -986,6 +1010,8 @@ class TransformerEncoderLayer(Cell):
        self.post_layernorm_residual = post_layernorm_residual
        self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
        self.dtype = mstype.float16
        self.key_past = None
        self.value_past = None

        if self.use_past:
            # operator used for state reuse

@@ -1058,8 +1084,10 @@ class TransformerEncoderLayer(Cell):
        r"""Check inputs"""
        _check_shape_equal(F.shape(x), "x", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
                           [self.batch_size, self.seq_length, self.seq_length])
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
        if init_reset is not True:
            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
@@ -1143,6 +1171,25 @@ class TransformerDecoderLayer(Cell):
        (2, 2, 20, 32)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                tgt_seq_length=Validator.check_positive_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "TransformerDecoderLayer"),
                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "TransformerDecoderLayer"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "TransformerDecoderLayer"),
                                parallel_config=_valid_type_checks([OpParallelConfig],
                                                                   "TransformerDecoderLayer"),
                                use_past=Validator.check_bool)
    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 num_heads,

@@ -1163,13 +1210,7 @@ class TransformerDecoderLayer(Cell):
        self.batch_size = batch_size
        self.use_past = use_past
        self.softmax_comptue_type = softmax_comptue_type
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
                f"but found {num_heads}")
        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
        if not isinstance(hidden_act, str):
            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")

        self.src_seq_length = src_seq_length
        self.tgt_seq_length = tgt_seq_length
        self.use_past = use_past

@@ -1217,6 +1258,8 @@ class TransformerDecoderLayer(Cell):
        self.post_layernorm_residual = post_layernorm_residual
        self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
        self.dtype = mstype.float16
        self.key_past = None
        self.value_past = None
        if self.use_past:
            # operator used for state reuse
            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))

@@ -1308,12 +1351,18 @@ class TransformerDecoderLayer(Cell):
                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                           [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
        _check_input_dtype(F.dtype(hidden_states), "hidden_size", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
        if encoder_output is not None:
            _check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
                               [self.batch_size, self.src_seq_length, self.hidden_size])
            _check_input_dtype(F.dtype(encoder_output), "encoder_output",
                               [mstype.float32, mstype.float16], self.cls_name)
        if memory_mask is not None:
            _check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
                               [self.batch_size, self.tgt_seq_length, self.src_seq_length])
            _check_input_dtype(F.dtype(memory_mask), "memory_mask",
                               [mstype.float32, mstype.float16], self.cls_name)

        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
        if init_reset is not True:
@@ -1437,6 +1486,26 @@ class TransformerEncoder(Cell):
        (2, 2, 16, 4)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                seq_length=Validator.check_positive_int,
                                num_layers=Validator.check_positive_int,
                                offset=Validator.check_non_negative_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "TransformerEncoder"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "TransformerEncoder"),
                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "TransformerEncoder"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "TransformerEncoder"),
                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
                                                                   "TransformerEncoder"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 batch_size,
                 num_layers,

@@ -1457,15 +1526,6 @@ class TransformerEncoder(Cell):
                 parallel_config=default_transformer_config):
        super(TransformerEncoder, self).__init__()
        _check_config(parallel_config)
        Validator.check_positive_int(num_layers, "num_layers")
        Validator.check_non_negative_int(offset, "offset")
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
                f"but found {num_heads}")
        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
        if not isinstance(hidden_act, str):
            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")

        self.num_layers = num_layers
        self.blocks = nn.CellList()
@@ -1587,6 +1647,27 @@ class TransformerDecoder(Cell):

    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                num_layers=Validator.check_positive_int,
                                tgt_seq_length=Validator.check_positive_int,
                                offset=Validator.check_non_negative_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "TransformerDecoder"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "TransformerDecoder"),
                                softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "TransformerDecoder"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "TransformerDecoder"),
                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
                                                                   "TransformerDecoder"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 num_layers,
                 batch_size,

@@ -1608,15 +1689,6 @@ class TransformerDecoder(Cell):
                 parallel_config=default_transformer_config):
        super(TransformerDecoder, self).__init__()
        _check_config(parallel_config)
        Validator.check_positive_int(num_layers, "num_layers")
        Validator.check_non_negative_int(offset, "offset")
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
                f"but found {num_heads}")
        Validator.check_bool(post_layernorm_residual, "post_layernorm_residual")
        if not isinstance(hidden_act, str):
            raise ValueError(f"The hidden_act should be a str type, but found {type(hidden_act)}")

        self.num_layers = num_layers
        self.blocks = nn.CellList()
@@ -1762,6 +1834,25 @@ class Transformer(Cell):

    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                encoder_layers=Validator.check_positive_int,
                                decoder_layers=Validator.check_non_negative_int,
                                tgt_seq_length=Validator.check_positive_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "Transformer"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_type_checks([mstype.float32, mstype.float16],
                                                                          "Transformer"),
                                softmax_comptue_type=_valid_type_checks([mstype.float32, mstype.float16],
                                                                        "Transformer"),
                                param_init_type=_valid_type_checks([mstype.float32, mstype.float16], "Transformer"),
                                parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 hidden_size,
                 batch_size,
@@ -207,7 +207,7 @@ def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config,
        print(f"pipeline stage id is {pp_id}", flush=True)

    # Used for optimizer's fusion tag
    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
    dis = max(int((layers + 1) / parallel_config.gradient_aggregation_group), 1)
    if parallel_config.pipeline_stage > 1:
        # we give the fusion in pipeline mode a fixed value, otherwise the performance may become worse.
        network.set_comm_fusion(2)
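The only functional change in this hunk is that the fusion-group stride is computed from layers + 1 instead of layers (old line shown first, new line second). With made-up values purely for illustration:

# Hypothetical numbers, just to show how the stride used for the optimizer's
# fusion tag shifts when the layer count is incremented before the division.
layers = 7
gradient_aggregation_group = 4

dis_old = max(int(layers / gradient_aggregation_group), 1)        # max(int(1.75), 1) -> 1
dis_new = max(int((layers + 1) / gradient_aggregation_group), 1)  # max(int(2.0), 1)  -> 2
print(dis_old, dis_new)  # 1 2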
@@ -25,7 +25,7 @@ from mindspore.common.api import _cell_graph_executor
def test_transformer_encoder_only():
    model = Transformer(batch_size=2,
                        src_seq_length=20,
                        tgt_seq_length=0,
                        tgt_seq_length=10,
                        encoder_layers=2,
                        decoder_layers=0,
                        hidden_size=64,

@@ -41,7 +41,7 @@ def test_transformer_encoder_log_softmax():
    with pytest.raises(ValueError):
        model = Transformer(batch_size=2,
                            src_seq_length=20,
                            tgt_seq_length=0,
                            tgt_seq_length=10,
                            encoder_layers=2,
                            decoder_layers=0,
                            hidden_act='logsoftmax',

@@ -57,7 +57,7 @@ def test_transformer_encoder_log_softmax():
def test_transformer_encoder_leakyrelu():
    model = Transformer(batch_size=2,
                        src_seq_length=20,
                        tgt_seq_length=0,
                        tgt_seq_length=10,
                        encoder_layers=2,
                        decoder_layers=0,
                        hidden_act='leakyrelu',

@@ -73,7 +73,7 @@ def test_transformer_encoder_leakyrelu():
def test_transformer_encoder_logsigmoid():
    model = Transformer(batch_size=2,
                        src_seq_length=20,
                        tgt_seq_length=0,
                        tgt_seq_length=10,
                        encoder_layers=2,
                        decoder_layers=0,
                        hidden_act='logsigmoid',
@@ -156,6 +156,37 @@ def test_transformer_model():
    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_int64_inputs():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    with pytest.raises(TypeError):
        model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_head_parallel_only_encoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
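The new test above relies on the tensor dtype guards (the _check_input_dtype calls seen earlier) rejecting the int64 encoder input at training time. A simplified, framework-free sketch of that guard pattern; the function name and dtype strings here are stand-ins, not the real MindSpore API:

def check_input_dtype(input_dtype, param_name, valid_dtypes, cls_name):
    """Reject inputs whose dtype is not in the allowed list, as the layer checks do."""
    if input_dtype not in valid_dtypes:
        raise TypeError(f"For '{cls_name}', the dtype of '{param_name}' should be one of "
                        f"{valid_dtypes}, but got {input_dtype}")


check_input_dtype("float16", "encoder_input_value", ["float32", "float16"], "Transformer")  # passes
try:
    check_input_dtype("int64", "encoder_input_value", ["float32", "float16"], "Transformer")
except TypeError as err:
    print(err)  # int64 inputs are rejected, which is what the test asserts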
@@ -453,6 +484,33 @@ def test_parallel_cross_entroy_loss_semi_auto_parallel():
    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_args():

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, decoder_layers="aa")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length="a")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, softmax_comptue_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, layernorm_compute_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, param_init_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, hidden_dropout_rate=mstype.int64)


def test_transformer_parallel_config():
    parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)