forked from mindspore-Ecosystem/mindspore
Remove redundant code, add docstrings for operator `__init__` methods, and add default values for arguments.
parent 494639ad8e
commit 185ddbbe66
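The recurring pattern across the hunks below is sketched here for reference. This is an illustrative reconstruction based on the first hunk (Softmax), not an excerpt from the commit itself: each operator `__init__` gains a one-line docstring, and optional arguments keep explicit defaults.

```python
# Illustrative sketch of the change pattern (modeled on the Softmax hunk).
# Assumes a MindSpore environment so that Cell and the Softmax primitive exist.
from mindspore.nn import Cell
from mindspore.ops import operations as P


class Softmax(Cell):
    """Softmax activation layer (full docstring elided)."""

    def __init__(self, axis=-1):
        """Initialize Softmax."""           # docstring added by this commit
        super(Softmax, self).__init__()
        self.softmax = P.Softmax(axis)      # axis keeps its default of -1

    def construct(self, x):
        return self.softmax(x)
```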
@@ -85,6 +85,7 @@ class Softmax(Cell):
"""
def __init__(self, axis=-1):
"""Initialize Softmax."""
super(Softmax, self).__init__()
self.softmax = P.Softmax(axis)

@@ -135,6 +136,7 @@ class LogSoftmax(Cell):
"""
def __init__(self, axis=-1):
"""Initialize LogSoftmax."""
super(LogSoftmax, self).__init__()
self.log_softmax = P.LogSoftmax(axis)

@@ -185,6 +187,7 @@ class ELU(Cell):
"""
def __init__(self, alpha=1.0):
"""Initialize ELU."""
super(ELU, self).__init__()
self.elu = P.Elu(alpha)

@@ -229,6 +232,7 @@ class ReLU(Cell):
"""
def __init__(self):
"""Initialize ReLU."""
super(ReLU, self).__init__()
self.relu = P.ReLU()

@@ -271,6 +275,7 @@ class ReLU6(Cell):
"""
def __init__(self):
"""Initialize ReLU6."""
super(ReLU6, self).__init__()
self.relu6 = P.ReLU6()

@@ -316,6 +321,7 @@ class LeakyReLU(Cell):
"""
def __init__(self, alpha=0.2):
"""Initialize LeakyReLU."""
super(LeakyReLU, self).__init__()
validator.check_value_type('alpha', alpha, [float, int], self.cls_name)
self.greater_equal = P.GreaterEqual()

@@ -366,6 +372,7 @@ class Tanh(Cell):
"""
def __init__(self):
"""Initialize Tanh."""
super(Tanh, self).__init__()
self.tanh = P.Tanh()

@@ -413,6 +420,7 @@ class GELU(Cell):
"""
def __init__(self):
"""Initialize GELU."""
super(GELU, self).__init__()
self.gelu = P.GeLU()

@@ -456,6 +464,7 @@ class FastGelu(Cell):
"""
def __init__(self):
"""Initialize FastGelu."""
super(FastGelu, self).__init__()
self.fast_gelu = P.FastGeLU()

@@ -501,6 +510,7 @@ class Sigmoid(Cell):
"""
def __init__(self):
"""Initialize Sigmoid."""
super(Sigmoid, self).__init__()
self.sigmoid = P.Sigmoid()

@@ -560,6 +570,7 @@ class PReLU(Cell):
"""
@cell_attr_register(attrs="")
def __init__(self, channel=1, w=0.25):
"""Initialize PReLU."""
super(PReLU, self).__init__()
validator.check_positive_int(channel, 'channel', self.cls_name)
if isinstance(w, (np.float32, float)):

@@ -619,6 +630,7 @@ class HSwish(Cell):
"""
def __init__(self):
"""Initialize HSwish."""
super(HSwish, self).__init__()
self.hswish = P.HSwish()

@@ -660,6 +672,7 @@ class HSigmoid(Cell):
"""
def __init__(self):
"""Initialize HSigmoid."""
super(HSigmoid, self).__init__()
self.hsigmoid = P.HSigmoid()

@@ -701,6 +714,7 @@ class LogSigmoid(Cell):
"""
def __init__(self):
"""Initialize LogSigmoid."""
super(LogSigmoid, self).__init__()
self.mul = P.Mul()
self.exp = P.Exp()
@@ -76,6 +76,7 @@ class L1Regularizer(Cell):
"""
def __init__(self, scale):
"""Initialize L1Regularizer."""
super(L1Regularizer, self).__init__()
Validator.check_value_type("scale", scale, [int, float], self.cls_name)
if scale <= 0:

@@ -143,6 +144,7 @@ class Dropout(Cell):
"""
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
"""Initialize Dropout."""
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))

@@ -197,6 +199,7 @@ class Flatten(Cell):
"""
def __init__(self):
"""Initialize Flatten."""
super(Flatten, self).__init__()
def construct(self, x):

@@ -269,6 +272,7 @@ class Dense(Cell):
bias_init='zeros',
has_bias=True,
activation=None):
"""Initialize Dense."""
super(Dense, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -276,7 +280,6 @@ class Dense(Cell):
self.reshape = P.Reshape()
self.shape_op = P.Shape()
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:

@@ -390,6 +393,7 @@ class ClipByNorm(Cell):
"""
def __init__(self, axis=None):
"""Initialize ClipByNorm."""
super(ClipByNorm, self).__init__()
if axis is None:
axis = ()

@@ -468,6 +472,7 @@ class Norm(Cell):
"""
def __init__(self, axis=(), keep_dims=False):
"""Initialize Norm."""
super(Norm, self).__init__()
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
self.axis = axis

@@ -561,6 +566,7 @@ class OneHot(Cell):
"""
def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
"""Initialize OneHot."""
super(OneHot, self).__init__()
self.onehot = P.OneHot(axis)
self.depth = depth

@@ -633,6 +639,7 @@ class Pad(Cell):
"""
def __init__(self, paddings, mode="CONSTANT"):
"""Initialize Pad."""
super(Pad, self).__init__()
self.mode = mode
self.paddings = paddings

@@ -722,6 +729,7 @@ class ResizeBilinear(Cell):
"""
def __init__(self):
"""Initialize ResizeBilinear."""
super(ResizeBilinear, self).__init__()
def construct(self, x, size=None, scale_factor=None, align_corners=False):

@@ -780,6 +788,7 @@ class Unfold(Cell):
"""
def __init__(self, ksizes, strides, rates, padding="valid"):
"""Initialize Unfold."""
super(Unfold, self).__init__()
def _check_tuple_or_list(arg_name, arg_val, prim_name):

@@ -841,6 +850,7 @@ class Tril(Cell):
"""
def __init__(self):
"""Initialize Tril."""
super(Tril, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()

@@ -888,6 +898,7 @@ class Triu(Cell):
"""
def __init__(self):
"""Initialize Triu."""
super(Triu, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()

@@ -946,6 +957,7 @@ class MatrixDiag(Cell):
"""
def __init__(self):
"""Initialize MatrixDiag."""
super(MatrixDiag, self).__init__()
self.matrix_diag = inner.MatrixDiag()
self.dtype = P.DType()

@@ -990,6 +1002,7 @@ class MatrixDiagPart(Cell):
"""
def __init__(self):
"""Initialize MatrixDiagPart."""
super(MatrixDiagPart, self).__init__()
self.matrix_diag_part = inner.MatrixDiagPart()
self.dtype = P.DType()

@@ -1048,6 +1061,7 @@ class MatrixSetDiag(Cell):
"""
def __init__(self):
"""Initialize MatrixSetDiag."""
super(MatrixSetDiag, self).__init__()
self.matrix_set_diag = inner.MatrixSetDiag()
self.dtype = P.DType()
@@ -112,6 +112,7 @@ class Conv2dBnAct(Cell):
activation=None,
alpha=0.2,
after_fake=True):
"""Initialize Conv2dBnAct."""
super(Conv2dBnAct, self).__init__()
self.conv = nn.Conv2d(in_channels,

@@ -206,6 +207,7 @@ class DenseBnAct(Cell):
activation=None,
alpha=0.2,
after_fake=True):
"""Initialize DenseBnAct."""
super(DenseBnAct, self).__init__()
self.dense = nn.Dense(
in_channels,
@@ -72,7 +72,7 @@ def _get_prefix_and_index(cells):
return prefix, index
class _CellListBase():
class _CellListBase:
"""
An interface for base the cell as list.

@@ -84,6 +84,7 @@ class _CellListBase():
by iterator or subscript , it will be interpreted as a list of cells.
"""
def __init__(self):
"""Initialize _CellListBase."""
self.__cell_as_list__ = True
@abstractmethod

@@ -133,6 +134,7 @@ class SequentialCell(Cell):
[27. 27.]]]]
"""
def __init__(self, *args):
"""Initialize SequentialCell."""
super(SequentialCell, self).__init__()
self._is_dynamic_name = []
if len(args) == 1:

@@ -270,6 +272,7 @@ class CellList(_CellListBase, Cell):
>
"""
def __init__(self, *args, **kwargs):
"""Initialize CellList."""
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
_CellListBase.__init__(self)
Cell.__init__(self, auto_prefix)
@@ -47,6 +47,7 @@ class _Conv(Cell):
bias_init,
data_format='NCHW',
transposed=False):
"""Initialize _Conv."""
super(_Conv, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -207,8 +208,8 @@ class Conv2d(_Conv):
Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 240, 1024, 640)
"""

@@ -227,6 +228,7 @@ class Conv2d(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCHW'):
"""Initialize Conv2d."""
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation

@@ -372,8 +374,8 @@ class Conv1d(_Conv):
Examples:
>>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 120, 640]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 240, 640)
"""

@@ -391,7 +393,7 @@ class Conv1d(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv1d."""
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)

@@ -575,9 +577,9 @@ class Conv3d(_Conv):
``Ascend``
Examples:
>>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
>>> x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
>>> conv3d = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(4, 3, 3))
>>> output = conv3d(input)
>>> output = conv3d(x)
>>> print(output.shape)
(16, 32, 10, 32, 32)
"""

@@ -596,6 +598,7 @@ class Conv3d(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCDHW'):
"""Initialize Conv3d."""
kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)

@@ -746,10 +749,10 @@ class Conv3dTranspose(_Conv):
ValueError: If `data_format` is not 'NCDHW'.
Examples:
>>> input = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> conv3d_transpose = nn.Conv3dTranspose(in_channels=16, out_channels=3, kernel_size=(4, 6, 2),
... pad_mode='pad')
>>> output = conv3d_transpose(input)
>>> output = conv3d_transpose(x)
>>> print(output.shape)
(32, 3, 13, 37, 33)
"""

@@ -768,6 +771,7 @@ class Conv3dTranspose(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCDHW'):
"""Initialize Conv3dTranspose."""
kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)

@@ -929,8 +933,8 @@ class Conv2dTranspose(_Conv):
Examples:
>>> net = nn.Conv2dTranspose(3, 64, 4, has_bias=False, weight_init='normal', pad_mode='pad')
>>> input = Tensor(np.ones([1, 3, 16, 50]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 3, 16, 50]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 64, 19, 53)
"""

@@ -947,6 +951,7 @@ class Conv2dTranspose(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv2dTranspose."""
kernel_size = twice(kernel_size)
stride = twice(stride)
dilation = twice(dilation)

@@ -1098,8 +1103,8 @@ class Conv1dTranspose(_Conv):
Examples:
>>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal', pad_mode='pad')
>>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 3, 50]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 64, 53)
"""

@@ -1116,6 +1121,7 @@ class Conv1dTranspose(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv1dTranspose."""
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
@@ -98,6 +98,7 @@ class Embedding(Cell):
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
"""Initialize Embedding."""
super(Embedding, self).__init__()
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)

@@ -223,6 +224,7 @@ class EmbeddingLookup(Cell):
def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True, vocab_cache_size=0):
"""Initialize EmbeddingLookup."""
super(EmbeddingLookup, self).__init__()
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')

@@ -416,11 +418,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently. Default: None.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
operator (str): The pooling method for the features in one field. Support 'SUM, 'MEAN' and 'MAX'
operator (str): The pooling method for the features in one field. Support 'SUM', 'MEAN' and 'MAX'.
Default: 'SUM'.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.

@@ -464,6 +467,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
"""Initialize MultiFieldEmbeddingLookup."""
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
slice_mode, feature_num_list, max_norm, sparse)
self.field_size = validator.check_positive_int(field_size, 'field_size')
@@ -120,10 +120,10 @@ class LSTM(Cell):
Examples:
>>> net = nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
>>> input = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> c0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> output, (hn, cn) = net(input, (h0, c0))
>>> output, (hn, cn) = net(x, (h0, c0))
>>> print(output.shape)
(3, 5, 16)
"""

@@ -136,6 +136,7 @@ class LSTM(Cell):
batch_first=False,
dropout=0,
bidirectional=False):
"""Initialize LSTM."""
super(LSTM, self).__init__()
validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)

@@ -360,11 +361,11 @@ class LSTMCell(Cell):
Examples:
>>> net = nn.LSTMCell(10, 12, has_bias=True, batch_first=True, bidirectional=False)
>>> input = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h = Tensor(np.ones([1, 3, 12]).astype(np.float32))
>>> c = Tensor(np.ones([1, 3, 12]).astype(np.float32))
>>> w = Tensor(np.ones([1152, 1, 1]).astype(np.float32))
>>> output, h, c, _, _ = net(input, h, c, w)
>>> output, h, c, _, _ = net(x, h, c, w)
>>> print(output.shape)
(3, 5, 12)
"""

@@ -376,6 +377,7 @@ class LSTMCell(Cell):
batch_first=False,
dropout=0,
bidirectional=False):
"""Initialize LSTMCell."""
super(LSTMCell, self).__init__()
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.transpose = P.Transpose()
@@ -99,6 +99,7 @@ class ReduceLogSumExp(Cell):
"""
def __init__(self, axis, keep_dims=False):
"""Initialize ReduceLogSumExp."""
super(ReduceLogSumExp, self).__init__()
validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)

@@ -129,7 +130,7 @@ class Range(Cell):
start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`. It can not be equal to `start`.
while set the first entry of the range to `0`. It can not be equal to `start`. Default: None.
delta (Union[int, float]): Increment of the range. It can not be equal to zero. Default: 1.
Outputs:

@@ -146,6 +147,7 @@ class Range(Cell):
"""
def __init__(self, start, limit=None, delta=1):
"""Initialize Range."""
super(Range, self).__init__()
if delta == 0:
raise ValueError("The input of `delta` can not be equal to zero.")

@@ -211,6 +213,7 @@ class LGamma(Cell):
"""
def __init__(self):
"""Initialize LGamma."""
super(LGamma, self).__init__()
# const numbers
self.k_lanczos_gamma = 7

@@ -258,7 +261,7 @@ class LGamma(Cell):
for i in range(8):
product_ = k_lanczos_coefficients[i] / (z + i + 1)
reflex_x = product_ + reflex_x
return reflex_x
return reflex_x
reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
t = z + self.lanczos_gamma_plus_one_half

@@ -322,6 +325,7 @@ class DiGamma(Cell):
"""
def __init__(self):
"""Initialize DiGamma."""
super(DiGamma, self).__init__()
# const numbers
self.k_lanczos_gamma = 7

@@ -360,7 +364,7 @@ class DiGamma(Cell):
for i in range(8):
num = num - k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
denom = denom + k_lanczos_coefficients[i] / (z + i + 1)
return num, denom
return num, denom
num, denom = _calculate_num_denom(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
t = z + self.lanczos_gamma_plus_one_half

@@ -586,6 +590,7 @@ class IGamma(Cell):
"""
def __init__(self):
"""Initialize IGamma."""
super(IGamma, self).__init__()
# const numbers
# If more data types are supported, this float max value need to be selected.

@@ -674,6 +679,7 @@ class LBeta(Cell):
"""
def __init__(self):
"""Initialize LBeta."""
super(LBeta, self).__init__()
# const numbers
self.log_2pi = np.log(2 * np.pi)

@@ -855,6 +861,7 @@ class MatMul(Cell):
@deprecated('1.2', 'ops.matmul', False)
def __init__(self, transpose_x1=False, transpose_x2=False):
"""Initialize MatMul."""
super(MatMul, self).__init__()
validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)

@@ -906,9 +913,9 @@ class Moments(Cell):
Calculates the mean and variance of `x`.
Args:
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: ().
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: None.
keep_dims (bool): If true, The dimension of mean and variance are identical with input's.
If false, don't keep these dimensions. Default: False.
If false, don't keep these dimensions. Default: None.
Inputs:
- **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.

@@ -938,6 +945,7 @@ class Moments(Cell):
"""
def __init__(self, axis=None, keep_dims=None):
"""Initialize Moments."""
super(Moments, self).__init__()
if axis is None:
axis = ()

@@ -997,6 +1005,7 @@ class MatInverse(Cell):
[2.1111116 -0.5555557 0.11111111]]
"""
def __init__(self):
"""Initialize MatInverse."""
super(MatInverse, self).__init__()
self.dtype = P.DType()
self.choleskytrsm = P.CholeskyTrsm()

@@ -1035,6 +1044,7 @@ class MatDet(Cell):
35.999996
"""
def __init__(self):
"""Initialize MatDet."""
super(MatDet, self).__init__()
self.dtype = P.DType()
self.cholesky = P.Cholesky()
@@ -58,6 +58,7 @@ class _BatchNorm(Cell):
process_groups=0,
input_dims='2d',
data_format='NCHW'):
"""Initialize _BatchNorm."""
super(_BatchNorm, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
if num_features < 1:

@@ -339,6 +340,7 @@ class BatchNorm1d(_BatchNorm):
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None):
"""Initialize BatchNorm1d."""
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,

@@ -447,6 +449,7 @@ class BatchNorm2d(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCHW'):
"""Initialize BatchNorm2d."""
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,

@@ -546,6 +549,7 @@ class BatchNorm3d(Cell):
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCDHW'):
"""Initialize BatchNorm3d."""
super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()

@@ -664,6 +668,7 @@ class GlobalBatchNorm(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=2):
"""Initialize GlobalBatchNorm."""
super(GlobalBatchNorm, self).__init__(num_features,
eps,
momentum,

@@ -782,6 +787,7 @@ class SyncBatchNorm(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
process_groups=None):
"""Initialize SyncBatchNorm."""
super(SyncBatchNorm, self).__init__(num_features,
eps,
momentum,

@@ -861,6 +867,7 @@ class LayerNorm(Cell):
beta_init='zeros',
epsilon=1e-7
):
"""Initialize LayerNorm."""
super(LayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."

@@ -965,6 +972,7 @@ class InstanceNorm2d(Cell):
affine=True,
gamma_init='ones',
beta_init='zeros'):
"""Initialize InstanceNorm2d."""
super(InstanceNorm2d, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
validator.check_value_type('eps', eps, [float], self.cls_name)

@@ -1073,6 +1081,7 @@ class GroupNorm(Cell):
"""
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
"""Initialize GroupNorm."""
super(GroupNorm, self).__init__()
self.num_groups = validator.check_positive_int(num_groups)
self.num_channels = validator.check_positive_int(num_channels)
@@ -27,6 +27,7 @@ class _PoolNd(Cell):
"""N-D AvgPool"""
def __init__(self, kernel_size, stride, pad_mode, data_format="NCHW"):
"""Initialize _PoolNd."""
super(_PoolNd, self).__init__()
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)

@@ -128,6 +129,7 @@ class MaxPool2d(_PoolNd):
"""
def __init__(self, kernel_size=1, stride=1, pad_mode="valid", data_format="NCHW"):
"""Initialize MaxPool2d."""
super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.max_pool = P.MaxPool(kernel_size=self.kernel_size,
strides=self.stride,

@@ -196,6 +198,7 @@ class MaxPool1d(_PoolNd):
"""
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
"""Initialize MaxPool1d."""
super(MaxPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)

@@ -288,6 +291,7 @@ class AvgPool2d(_PoolNd):
stride=1,
pad_mode="valid",
data_format="NCHW"):
"""Initialize AvgPool2d."""
super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.avg_pool = P.AvgPool(kernel_size=self.kernel_size,
strides=self.stride,

@@ -359,6 +363,7 @@ class AvgPool1d(_PoolNd):
kernel_size=1,
stride=1,
pad_mode="valid"):
"""Initialize AvgPool1d."""
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
@@ -156,6 +156,7 @@ class _Observer(Cell):
"""
def __init__(self, quant_dtype):
"""Initialize _Observer."""
super(_Observer, self).__init__()
self.quant_dtype = quant_dtype

@@ -204,6 +205,7 @@ class UniformQuantObserver(_Observer):
def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
num_channels=1):
"""Initialize UniformQuantObserver."""
super(UniformQuantObserver, self).__init__(quant_dtype)
self.per_channel = per_channel
self.symmetric = symmetric

@@ -1109,6 +1111,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
bias_init='zeros',
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize Conv2dBnWithoutFoldQuant."""
super(Conv2dBnWithoutFoldQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -1249,6 +1252,7 @@ class Conv2dQuant(Cell):
bias_init='zeros',
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize Conv2dQuant."""
super(Conv2dQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -1380,6 +1384,7 @@ class DenseQuant(Cell):
activation=None,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize DenseQuant."""
super(DenseQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -1495,6 +1500,7 @@ class ActQuant(_QuantActivation):
fake_before=False,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize ActQuant."""
super(ActQuant, self).__init__()
act_class = activation.__class__
act_list = [nn.ReLU, nn.ReLU6]

@@ -1578,6 +1584,7 @@ class TensorAddQuant(Cell):
ema_decay=0.999,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize TensorAddQuant."""
super(TensorAddQuant, self).__init__()
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,

@@ -1638,6 +1645,7 @@ class MulQuant(Cell):
ema_decay=0.999,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize MulQuant."""
super(MulQuant, self).__init__()
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
@@ -55,19 +55,19 @@ class DenseThor(Cell):
activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
Default: None.
Raises:
ValueError: If weight_init shape or bias_init shape is incorrect.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Raises:
ValueError: If the shape of `weight_init` or `bias_init` is incorrect.
Examples:
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> x = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net = nn.DenseThor(3, 4)
>>> net(input)
>>> net(x)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
"""

@@ -78,6 +78,7 @@ class DenseThor(Cell):
bias_init='zeros',
has_bias=True,
activation=None):
"""Initialize DenseThor."""
super(DenseThor, self).__init__()
self.thor = True
self.in_channels = Validator.check_positive_int(in_channels)

@@ -117,7 +118,6 @@ class DenseThor(Cell):
self.cube_matmul = P.MatMul(transpose_a=True)
self.getG = P.InsertGradientOf(self.save_gradient)
def _process_ascend_dense_thor(self, out_channels):
"""process ascend dense thor"""
if out_channels == 1001:

@@ -139,7 +139,6 @@ class DenseThor(Cell):
self.cast = P.Cast()
self.is_nsp_layer = (out_channels == 2)
def save_gradient(self, dout):
"""
this function only for thor optimizer

@@ -201,6 +200,7 @@ class _ConvThor(Cell):
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode,
padding, dilation, group, has_bias, weight_init, bias_init, transposed=False):
"""Initialize _ConvThor."""
super(_ConvThor, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)

@@ -353,14 +353,15 @@ class Conv2dThor(_ConvThor):
Examples:
>>> net = nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> print(net(input).shape)
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> print(net(x).shape)
(1, 240, 1024, 640)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
pad_mode='same', padding=0, dilation=1, group=1, has_bias=False,
weight_init='normal', bias_init='zeros'):
"""Initialize Conv2dThor."""
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation

@@ -448,7 +449,6 @@ class Conv2dThor(_ConvThor):
self.weight_init.shape = weight_shape
self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')
def save_gradient(self, dout):
"""save_gradient"""
out = dout

@@ -474,7 +474,6 @@ class Conv2dThor(_ConvThor):
self.matrix_g_cov = matrix_g
return out
def construct(self, x):
if self.thor:
matrix_a = self.img2col(x)

@@ -563,6 +562,7 @@ class EmbeddingThor(Cell):
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
"""Initialize EmbeddingThor."""
super(EmbeddingThor, self).__init__()
self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)

@@ -602,7 +602,6 @@ class EmbeddingThor(Cell):
self.cube_matmul = P.MatMul(transpose_a=True)
self.mul = P.Mul()
def save_gradient(self, dout):
"""
this function only for thor optimizer

@@ -634,11 +633,9 @@ class EmbeddingThor(Cell):
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, out_shape)
return output
def extend_repr(self):
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
@@ -85,15 +85,16 @@ class TimeDistributed(Cell):
TypeError: If layer is not a Cell or Primitive.
Examples:
>>> input = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
>>> x = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
>>> dense = nn.Dense(3, 6)
>>> net = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
>>> output = net(input)
>>> output = net(x)
>>> print(output.shape)
(32, 10, 6)
"""
def __init__(self, layer, time_axis, reshape_with_axis=None):
"""Initialize TimeDistributed."""
if not isinstance(layer, (Cell, Primitive)):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))
@@ -46,6 +46,7 @@ class Loss(Cell):
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self, reduction='mean'):
"""Initialize Loss."""
super(Loss, self).__init__()
if reduction not in ('mean', 'sum', 'none'):

@@ -96,6 +97,7 @@ class _Loss(Loss):
Base class for other losses.
"""
def __init__(self, reduction='mean'):
"""Initialize _Loss."""
log.warning("'_Loss' is deprecated from version 1.3 and "
"will be removed in a future version, use 'Loss' instead.")
super(_Loss, self).__init__()

@@ -150,6 +152,7 @@ class L1Loss(Loss):
0.33333334
"""
def __init__(self, reduction='mean'):
"""Initialize L1Loss."""
super(L1Loss, self).__init__(reduction)
self.abs = P.Abs()

@@ -238,6 +241,7 @@ class RMSELoss(Loss):
0.57735026
"""
def __init__(self):
"""Initialize RMSELoss."""
super(RMSELoss, self).__init__()
self.MSELoss = MSELoss()

@@ -285,6 +289,7 @@ class MAELoss(Loss):
0.33333334
"""
def __init__(self, reduction='mean'):
"""Initialize MAELoss."""
super(MAELoss, self).__init__(reduction)
self.abs = P.Abs()

@@ -347,6 +352,7 @@ class SmoothL1Loss(Loss):
[0. 0. 0.5]
"""
def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss."""
super(SmoothL1Loss, self).__init__()
self.beta = beta
self.smooth_l1_loss = P.SmoothL1Loss(self.beta)

@@ -418,6 +424,7 @@ class SoftmaxCrossEntropyWithLogits(Loss):
def __init__(self,
sparse=False,
reduction='none'):
"""Initialize SoftmaxCrossEntropyWithLogits."""
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = validator.check_bool(sparse, "sparse")
self.reduction = reduction

@@ -481,6 +488,7 @@ class DiceLoss(Loss):
0.38596618
"""
def __init__(self, smooth=1e-5):
"""Initialize DiceLoss."""
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()

@@ -559,6 +567,7 @@ class MultiClassDiceLoss(Loss):
0.3283009
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
"""Initialize MultiClassDiceLoss."""
super(MultiClassDiceLoss, self).__init__()
activation_list = ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'sigmoid']

@@ -604,7 +613,7 @@ class SampledSoftmaxLoss(Loss):
Args:
num_sampled (int): The number of classes to randomly sample per batch.
num_classes (int): The number of possible classes.
num_true (int): The number of target classes per training example.
num_true (int): The number of target classes per training example. Default: 1.
sampled_values (Union[list, tuple]): List or tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*CandidateSampler` function.
Default to None, `UniformCandidateSampler` is applied.

@@ -650,6 +659,7 @@ class SampledSoftmaxLoss(Loss):
def __init__(self, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=True, seed=0,
reduction='none'):
"""Initialize SampledSoftmaxLoss."""
super(SampledSoftmaxLoss, self).__init__(reduction)
if num_true < 1:

@@ -877,6 +887,7 @@ class BCELoss(Loss):
"""
def __init__(self, weight=None, reduction='none'):
"""Initialize BCELoss."""
super(BCELoss, self).__init__()
self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
self.weight_one = weight is None

@@ -946,6 +957,7 @@ class CosineEmbeddingLoss(Loss):
0.0003426075
"""
def __init__(self, margin=0.0, reduction="mean"):
"""Initialize CosineEmbeddingLoss."""
super(CosineEmbeddingLoss, self).__init__(reduction)
self.reduce_sum = P.ReduceSum()
self.maximum = P.Maximum()

@@ -1035,6 +1047,7 @@ class BCEWithLogitsLoss(Loss):
"""
def __init__(self, reduction='mean', weight=None, pos_weight=None):
"""Initialize BCEWithLogitsLoss."""
super(BCEWithLogitsLoss, self).__init__()
self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
if isinstance(weight, Parameter):

@@ -1139,6 +1152,7 @@ class FocalLoss(Loss):
"""
def __init__(self, weight=None, gamma=2.0, reduction='mean'):
"""Initialize FocalLoss."""
super(FocalLoss, self).__init__(reduction=reduction)
self.gamma = validator.check_value_type("gamma", gamma, [float])
@@ -57,6 +57,7 @@ class SparseToDense(Cell):
"""
def __init__(self):
"""Initialize SparseToDense."""
super(SparseToDense, self).__init__()
self.sparse_to_dense = P.SparseToDense()
@@ -29,7 +29,7 @@ from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
from ...common.tensor import RowTensor
from .._utils.utils import range_op, get_1d_shape
from .._utils.utils import range_op, get_1d_shape, generate_shape_index
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()

@@ -358,18 +358,6 @@ def get_bprop_slice(self):
return bprop
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank))
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm
@constexpr
def _generate_inverse_index(x_shape, axis):
x_rank = len(x_shape)

@@ -409,7 +397,7 @@ def get_bprop_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
if -1 in shape_op(x):
params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])

@@ -488,7 +476,7 @@ def get_bprop_sparse_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)

@@ -680,7 +668,7 @@ def get_bprop_oneslike(self):
@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
"""Generate bprop for OnesLike"""
"""Generate bprop for ZerosLike"""
def bprop(x, out, dout):
return (zeros_like(x),)
@@ -26,17 +26,9 @@ get_dtype = P.DType()
@bprops.register("MaximumGrad")
def bprop_maximum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MaximumGrad`."""
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]
return F.zeros_like(x), F.zeros_like(y), dz
@bprops.register("MinimumGrad")
def bprop_minimum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MinimumGrad`."""
def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MaximumGrad` and `MinimumGrad`."""
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]
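The two `bprops.register` calls above appear to end up stacked on the single merged function. A minimal sketch of that registration pattern (an assumed form, not copied from the commit; `bprops`, `F`, and `get_dtype` are the names already defined in the file being patched):

```python
# Assumed final form after the merge: one backpropagator registered for both
# MaximumGrad and MinimumGrad via decorator stacking; the body mirrors the hunk above.
@bprops.register("MaximumGrad")
@bprops.register("MinimumGrad")
def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
    """Backpropagator for primitive `MaximumGrad` and `MinimumGrad`."""
    out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
    out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
    dz = out0 * dout[0] + out1 * dout[1]
    return F.zeros_like(x), F.zeros_like(y), dz
```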
@@ -298,7 +298,6 @@ def get_bprop_floor(self):
bc_x = fill_(dtype_(x), shape_(x), 0.)
return (bc_x,)
return bprop

@@ -420,7 +419,7 @@ def get_bprop_xlogy(self):
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
"""Grad definition for `Square` operation."""
"""Grad definition for `SquareSumAll` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
dtype = P.DType()
@@ -622,6 +622,7 @@ def get_bprop_tanh_grad(self):
return bprop
@bprop_getters.register(P.Gelu)
@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
"""Grad definition for `GeLU` operation."""

@@ -634,18 +635,6 @@ def get_bprop_gelu(self):
return bprop
@bprop_getters.register(P.Gelu)
def get_bprop_gelu_2(self):
"""Grad definition for `GeLU` operation."""
input_grad = G.GeLUGrad()
def bprop(x, out, dout):
dx = input_grad(dout, x, out)
return (dx,)
return bprop
@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
"""Grad definition for `FastGeLU` operation."""

@@ -1156,28 +1145,9 @@ def get_bprop_dropout(self):
@bprop_getters.register(P.Dropout2D)
def get_bprop_dropout2d(self):
"""Grad definition for `Dropout2D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()
keep_prob = self.keep_prob
def bprop(x, out, dout):
_, mask = dout
y = cast(mask, mstype.float32)
if keep_prob != 0:
y = y * (1 / keep_prob)
y = mul(x, y)
y = cast(y, dtype(x))
return (y,)
return bprop
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
"""Grad definition for `Dropout3D` operation."""
"""Grad definition for `Dropout2D` and `Dropout3D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()
@@ -28,6 +28,7 @@ from .grad_base import bprop_getters
@bprop_getters.register(P.Assign)
def get_bprop_assign(self):
"""Generate bprop for Assign"""
def bprop(x, y, out, dout):
return (dout, zeros_like(y))
return bprop

@@ -88,6 +89,7 @@ def get_bprop_sync_batch_norm(self):
@bprop_getters.register(inner.GpuConvertToDynamicShape)
def get_bprop_gpu_convert_to_dynamic_shape(self):
"""Get backprop for GpuConvertToDynamicShape."""
def bprop(x, out, dout):
return (dout,)
return bprop
@@ -139,6 +139,7 @@ def get_bprop_BatchNormFold(self):
@bprop_getters.register(P.BNTrainingReduce)
def get_bprop_BNTrainingReduce(self):
"""Generate bprop for BNTrainingReduce for Ascend"""
def bprop(x, out, dout):
return (zeros_like(x),)

@@ -199,6 +200,7 @@ def get_bprop_acts_ulq(self):
@bprop_getters.register(Q.WtsARQ)
def get_bprop_wts_arq(self):
"""Grad definition for 'WtsArq' operation"""
def bprop(w, w_min, w_max, out, dout):
return (dout, zeros_like(w_min), zeros_like(w_max))
@@ -107,3 +107,15 @@ def get_1d_shape(in_shape):
for i in in_shape:
out_shape *= i
return (out_shape,)
@constexpr
def generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank))
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm
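For reference, a standalone sketch of what the relocated `generate_shape_index` helper computes. This is a plain-Python equivalent of the `@constexpr` version added above; the example shapes are illustrative, chosen to match the `out_shape:(3,2,3) axis 1 -> (1,0,2)` comment in the gather bprop:

```python
# Plain-Python equivalent of the generate_shape_index helper moved into _utils/utils.py.
# It builds the permutation that brings the gathered (indices) axes to the front.
def generate_shape_index(out_shape, indices_shape, axis):
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))   # positions of the gathered axes
    index = tuple(range(out_rank))
    return perm_part1 + index[:axis] + index[axis + ind_rank:]

print(generate_shape_index((3, 2, 3), (2,), 1))  # (1, 0, 2), as in the bprop comment
```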
@@ -316,6 +316,7 @@ class GradOperation(GradOperation_):
"""
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
"""Initialize GradOperation."""
if not isinstance(get_all, bool):
raise TypeError(f'get_all should be bool, but got {type(get_all)}')
if not isinstance(get_by_list, bool):

@@ -425,6 +426,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
"""
def __init__(self, name, read_value=False):
"""Initialize MultitypeFuncGraph."""
MultitypeFuncGraph_.__init__(self, name)
self.entries = list()
if read_value:

@@ -523,6 +525,7 @@ class HyperMap(HyperMap_):
"""
def __init__(self, ops=None):
"""Initialize HyperMap."""
self.ops = ops
if ops:
HyperMap_.__init__(self, ops)

@@ -586,6 +589,7 @@ class Map(Map_):
"""
def __init__(self, ops=None):
"""Initialize Map."""
self.ops = ops
if ops:
Map_.__init__(self, ops)

@@ -610,6 +614,7 @@ class _ListAppend(ListAppend_):
"""
def __init__(self, name):
"""Initialize _ListAppend."""
ListAppend_.__init__(self, name)
def __call__(self, *args):

@@ -628,6 +633,7 @@ class _Tail(Tail_):
"""
def __init__(self, name):
"""Initialize _Tail."""
Tail_.__init__(self, name)
def __call__(self, *args):

@@ -641,6 +647,7 @@ class _ZipOperation(ZipOperation_):
"""Generates a tuple of zip iterations for inputs."""
def __init__(self, name):
"""Initialize _ZipOperation."""
ZipOperation_.__init__(self, name)
def __call__(self, *args):
@@ -115,6 +115,7 @@ class _ClipByGlobalNorm(Cell):
"""
def __init__(self, clip_norm=1.0, use_norm=None):
"""Initialize _ClipByGlobalNorm."""
super(_ClipByGlobalNorm, self).__init__()
# Add interface. This parameter is not used at present
if use_norm is not None:
@@ -135,7 +135,7 @@ def _axes_int_check(x1_shape, x2_shape, axes):
raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
if axes == 0:
# outer product, no input validation required
return ([], [])
return [], []
if axes > len(x1_shape) or axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")

@@ -599,7 +599,7 @@ def _check_matmul_shapes(shape1, shape2):
@constexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
size = [1]*ndim
size = [1] * ndim
for idx, (i, j) in enumerate(zip(shape, out_shape)):
if i != j:
size[idx] = j
|
|||
|
||||
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
indexes_types = hyper_map(F.dtype, tuple_index)
|
||||
const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
|
||||
tensor_index_shape = hyper_map(F.shape, tuple_index)
|
||||
|
@ -615,11 +614,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
index_shape = F.shape(index)
|
||||
data_shape = F.shape(data)
|
||||
data_shape = const_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
const_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.shape_mul(F.shape(value))
|
||||
size = const_utils.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
const_utils.check_equal(1, size, "When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
|
|
|
@@ -505,7 +505,6 @@ def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_ind
tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)
index_tensor_new_shape, final_shape = [], []
final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \

@@ -619,14 +618,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
if ellipsis_count > 1:
raise IndexError("An index can have only one ellipsis (...)")
ellipsis_range_size = data_dim - tuple_index_len + 1
begin_strides.extend([0] * (ellipsis_range_size))
begin_strides.extend([0] * ellipsis_range_size)
end_strides.extend(
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
step_strides.extend([1] * (ellipsis_range_size))
step_strides.extend([1] * ellipsis_range_size)
index_count = index_count + ellipsis_range_size
else:
raise IndexError("Not supported index data type, got ",
index, " type is ", type(item))
index, " type is ", type(index))
for index in range(index_count, data_dim):
begin_strides.append(0)
end_strides.append(data_shape[index])
@@ -44,6 +44,7 @@ class _TupleAdd(base.TupleAdd_):
"""
def __init__(self, name):
"""Initialize _TupleAdd."""
base.TupleAdd_.__init__(self, name)
def __call__(self, *args):
@@ -39,6 +39,7 @@ class _TupleSlice(base.TupleSlice_):
"""
def __init__(self, name):
"""Initialize _TupleSlice."""
base.TupleSlice_.__init__(self, name)
def __call__(self, *args):

@@ -61,6 +62,7 @@ class _TupleGetItemTensor(base.TupleGetItemTensor_):
"""
def __init__(self, name):
"""Initialize _TupleGetItemTensor."""
base.TupleGetItemTensor_.__init__(self, name)
def __call__(self, *args):
@@ -51,4 +51,4 @@ def _greater_equal_tensor(x, y):
Returns:
Tensor, return value by operator P.GreaterEqual.
"""
return F.tensor_ge(x, y)
return F.tensor_ge(x, y)
@@ -51,4 +51,4 @@ def _less_equal_tensor(x, y):
Returns:
Tensor, return value by operator P.LessEqual.
"""
return F.tensor_le(x, y)
return F.tensor_le(x, y)
@@ -49,4 +49,4 @@ def _logical_and_tensor(x, y):
Returns:
Tensor, Return logical and operation result of x and y.
"""
return F.logical_and(x, y)
return F.logical_and(x, y)
@ -68,10 +68,6 @@ class RegOp:
|
|||
|
||||
Args:
|
||||
op_name (str): Name of op.
|
||||
inputs (list): Inputs information of the op.
|
||||
outputs (list): Outputs information of the op.
|
||||
attr_ (list): Attribute information of the op.
|
||||
dtype_format_ (list): Dtype and format information of the op.
|
||||
"""
|
||||
|
||||
def __init__(self, op_name=""):
|
||||
|
@@ -343,64 +339,13 @@ class AkgAscendRegOp(AkgRegOp):
super(AkgAscendRegOp, self).__init__(op_name, "AiCore")


class AiCPURegOp(RegOp):
class AiCPURegOp(CpuRegOp):
"""Class for AiCPU op info register"""

def __init__(self, op_name):
super(AiCPURegOp, self).__init__(op_name)
self.imply_type = "AiCPU"

def input(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op input information.

Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
param_type (str): Param type of the input. Default: None.
kwargs (dict): Other information of the input.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.inputs.append(input_dict)
return self

def output(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op output information.

Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
param_type (str): Param type of the output. Default: None.
kwargs (dict): Other information of the output.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.outputs.append(output_dict)
return self

def attr(self, name=None, value_type=None, value=None, **kwargs):
"""
Register AiCPU op attribute information.

Args:
name (str): Name of the attribute. Default: None.
value_type (str): Value type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
kwargs (dict): Other information of the attribute.
"""
param_list = [name, value_type, value]
key_list = ["name", "type", "value"]
fn_list = [self._is_string]
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.attr_.append(attr_dict)
return self
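The input/output/attr helpers removed above are chainable (each returns self), and the same interface is expected from the new CpuRegOp base class. A minimal registration sketch under that assumption, usable from within this module (or via the usual mindspore.ops.op_info_register import); the op and attribute names are illustrative, and the get_op_info() accessor is assumed from the shared RegOp helpers rather than shown in this diff:

    my_aicpu_op_info = AiCPURegOp("MyCustomOp") \
        .input(0, "x", "required") \
        .input(1, "y", "required") \
        .output(0, "z", "required") \
        .attr("transpose", "bool") \
        .get_op_info()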
class TBERegOp(RegOp):
"""Class for TBE operator information register."""

@@ -419,7 +364,7 @@ class TBERegOp(RegOp):
self.is_dynamic_format_ = False
self.op_pattern_ = ""

def async_flag(self, async_flag):
def async_flag(self, async_flag=False):
"""
Define the calculation efficiency of the operator, whether the asynchronous calculation is supported.

@@ -441,7 +386,7 @@ class TBERegOp(RegOp):
self.binfile_name_ = binfile_name
return self

def compute_cost(self, compute_cost):
def compute_cost(self, compute_cost=10):
"""
Define the calculation efficiency of operator, which refers to the value of the cost model
in the tiling module.

@@ -464,7 +409,7 @@ class TBERegOp(RegOp):
self.kernel_name_ = kernel_name
return self

def partial_flag(self, partial_flag):
def partial_flag(self, partial_flag=True):
"""
Define the calculation efficiency of operator, whether the partial calculation is supported.

@@ -486,7 +431,7 @@ class TBERegOp(RegOp):
self.reshape_type_ = reshape_type
return self

def dynamic_shape(self, dynamic_shape):
def dynamic_shape(self, dynamic_shape=False):
"""
Whether the operator supports dynamic shape.

@@ -497,7 +442,7 @@ class TBERegOp(RegOp):
self.dynamic_shape_ = dynamic_shape
return self

def need_check_supported(self, need_check_supported):
def need_check_supported(self, need_check_supported=False):
"""
Whether the operator need check supports.

@@ -508,7 +453,7 @@ class TBERegOp(RegOp):
self.need_check_supported_ = need_check_supported
return self

def is_dynamic_format(self, is_dynamic_format):
def is_dynamic_format(self, is_dynamic_format=False):
"""
Whether the operator need calop_select_format api.
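With the defaults added above, a TBE registration chain only has to override the flags that differ from the common case; each setter returns self, so the calls compose. A sketch under that assumption (the op name, binfile and kernel values are illustrative, and get_op_info() is assumed from the shared RegOp helpers):

    example_op_info = TBERegOp("ExampleOp") \
        .async_flag() \
        .binfile_name("example_op.so") \
        .compute_cost() \
        .kernel_name("example_op") \
        .partial_flag() \
        .dynamic_shape() \
        .need_check_supported() \
        .get_op_info()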
@@ -154,22 +154,9 @@ class RsqrtGrad(PrimitiveWithInfer):
return x_dtype


class SoftmaxGrad(PrimitiveWithInfer):
class SoftmaxGrad(ReciprocalGrad):
"""Performs grad of Softmax operation."""

@prim_attr_register
def __init__(self):
"""Initialize SoftmaxGrad"""

def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape

def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype


class SqrtGrad(PrimitiveWithInfer):
"""Performs grad of Sqrt operation."""
@ -347,7 +334,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
|
|||
(32, 32, 4, 6, 2)
|
||||
"""
|
||||
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
out_channel,
|
||||
|
@ -644,7 +630,7 @@ class DropoutGrad(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
|
||||
means dropping out 10% of input units.
|
||||
means dropping out 10% of input units. Default: 0.5.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple[int]) - The shape of target mask.
|
||||
|
@ -1030,6 +1016,9 @@ class MaximumGrad(Primitive):
|
|||
def __init__(self, grad_x=True, grad_y=True):
|
||||
"""Initialize MaximumGrad"""
|
||||
|
||||
def __call__(self, x, y, dout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaxPoolGradWithArgmax(_PoolGrad):
|
||||
"""Computes the gradients of MaxPoolWithArgmax."""
|
||||
|
@ -1688,6 +1677,7 @@ class ROIAlignGrad(PrimitiveWithInfer):
|
|||
ROIAlignGrad operator.
|
||||
|
||||
Args:
|
||||
xdiff_shape (tuple): The diff shape.
|
||||
pooled_height (int): The output feature height.
|
||||
pooled_width (int): The output feature width.
|
||||
spatial_scale (float): The feature stride.
|
||||
|
@ -1892,7 +1882,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
|||
|
||||
|
||||
class SoftplusGrad(PrimitiveWithInfer):
|
||||
"""Computes gradient for the Log Softmax activation."""
|
||||
"""Computes gradient for the Softplus activation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
|
|
@ -78,13 +78,14 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, input_x):
|
||||
"""infer shape"""
|
||||
if len(input_x) != 4:
|
||||
raise ValueError("The `input_x` should be a 4-D tensor, "
|
||||
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
|
||||
|
||||
in_batch, in_depth, in_row, in_col = input_x
|
||||
_, _, ksize_row, ksize_col = self.ksizes
|
||||
_, _, stride_row, stride_col = self.strides
|
||||
_, _, rate_row, rate_col = self.rates
|
||||
if len(input_x) != 4:
|
||||
raise ValueError("The `input_x` should be a 4-D tensor, "
|
||||
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
|
||||
|
||||
out_batch = in_batch
|
||||
out_depth = ksize_row * ksize_col * in_depth
|
||||
|
@ -124,7 +125,7 @@ class Range(PrimitiveWithInfer):
|
|||
start (float): If `limit` is `None`, the value acts as limit in the range and first entry
|
||||
defaults to `0`. Otherwise, it acts as first entry in the range.
|
||||
limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
|
||||
while set the first entry of the range to `0`. It can not be equal to `start`.
|
||||
while set the first entry of the range to `0`. It can not be equal to `start`. Default: None.
|
||||
delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
|
@ -134,9 +135,9 @@ class Range(PrimitiveWithInfer):
|
|||
Tensor, has the same shape and dtype as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> range = ops.Range(1.0, 8.0, 2.0)
|
||||
>>> range_op = ops.Range(1.0, 8.0, 2.0)
|
||||
>>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
|
||||
>>> output = range(x)
|
||||
>>> output = range_op(x)
|
||||
>>> print(output)
|
||||
[3, 5, 7, 5]
|
||||
"""
|
||||
|
@@ -906,7 +907,7 @@ class StackInit(PrimitiveWithInfer):
at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.

Supported Platforms:
``Ascend``

@@ -940,7 +941,7 @@ class StackPush(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.

Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.

Inputs:
- **input** (Tensor) - A tensor to be pushed onto the stack.

@@ -966,9 +967,9 @@ class StackPop(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.

Args:
index (int): The index of the stack.
shape (tuple): The shape of the tensor at the top of the stack.
dtype (mindspore.dtype): The type of the tensor at the top of the stack.
index (int): The index of the stack. Default: 1.
shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.

Outputs:
- **output** (Tensor) - The tensor at the top of the stack.

@@ -1010,7 +1011,7 @@ class StackDestroy(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.

Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.

Supported Platforms:
``Ascend``
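Taken together, the four stack primitives above follow an init/push/pop/destroy sequence, as their docstrings describe. An illustrative sketch with the new defaults; the tensor values and the shape/dtype passed to StackPop are assumptions for the example, not part of this change:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    stack_init = ops.StackInit(index=1)
    stack_push = ops.StackPush(index=1)
    stack_pop = ops.StackPop(index=1, shape=(2, 3), dtype=mindspore.float32)
    stack_destroy = ops.StackDestroy(index=1)

    stack_init()                                          # create stack 1
    stack_push(Tensor(np.ones((2, 3)), mindspore.float32))
    output = stack_pop()                                  # tensor at the top of the stack
    stack_destroy()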
@ -136,9 +136,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
|
||||
>>> min_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> max_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min_value, max_value)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
ascend_support_x_rank = [2, 4]
|
||||
|
@ -519,7 +519,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
|
|||
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||
>>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
|
||||
... input_tensor, min_tensor, max_tensor)
|
||||
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
|
||||
>>> output_tensor # shape: (3, 16, 5, 5) data type: mstype.float32
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -581,9 +581,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
|
|||
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||
>>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8,narrow_range=False)
|
||||
... (gradients, input_tensor, min_tensor, max_tensor)
|
||||
>>> x_gradient shape: (3, 16, 5, 5) data type: mstype.float32
|
||||
>>> min_gradient shape: (1,) data type: mstype.float32
|
||||
>>> max_gradient shape: (1,) data type: mstype.float32
|
||||
>>> x_gradient # shape: (3, 16, 5, 5) data type: mstype.float32
|
||||
>>> min_gradient # shape: (1,) data type: mstype.float32
|
||||
>>> max_gradient # shape: (1,) data type: mstype.float32
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -642,7 +642,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
|
|||
>>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
|
||||
>>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
|
||||
... input_tensor, min_tensor, max_tensor)
|
||||
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
|
||||
>>> output_tensor # shape: (3, 16, 3, 4) data type: mstype.float32
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -698,9 +698,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
|
|||
>>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
|
||||
... num_bits=8, narrow_range=False)(
|
||||
... gradients, input_tensor, min_tensor, max_tensor)
|
||||
>>> x_gradient shape: (3, 16, 3, 4) data type: mstype.float32
|
||||
>>> min_gradient shape: (4,) data type: mstype.float32
|
||||
>>> max_gradient shape: (4,) data type: mstype.float32
|
||||
>>> x_gradient # shape: (3, 16, 3, 4) data type: mstype.float32
|
||||
>>> min_gradient # shape: (4,) data type: mstype.float32
|
||||
>>> max_gradient # shape: (4,) data type: mstype.float32
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@@ -728,6 +728,28 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_type, min_type, max_type


def _fake_quant_per_infer_dtype(prim_name, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type


def _fake_quant_per_grad_infer_dtype(prim_name, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type


class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulates the quantize and dequantize operations in training time.
@ -797,19 +819,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, min_type, max_type):
|
||||
if context.get_context('device_target') == "GPU":
|
||||
valid_dtypes = (mstype.float32,)
|
||||
else:
|
||||
valid_dtypes = (mstype.float16, mstype.float32)
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
|
||||
("x", "min", "max"),
|
||||
(x_type, min_type, max_type)))
|
||||
return x_type
|
||||
return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
|
||||
|
||||
|
||||
class FakeQuantPerLayerGrad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs grad of FakeQuantPerLayerGrad operation.
|
||||
Performs grad of FakeQuantPerLayer operation.
|
||||
|
||||
Examples:
|
||||
>>> fake_min_max_grad = FakeQuantPerLayerGrad()
|
||||
|
@ -852,14 +867,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
|
|||
return dout_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
||||
if context.get_context('device_target') == "GPU":
|
||||
valid_dtypes = (mstype.float32,)
|
||||
else:
|
||||
valid_dtypes = (mstype.float16, mstype.float32)
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
|
||||
("dout", "x", "min", "max"),
|
||||
(dout_type, x_type, min_type, max_type)))
|
||||
return dout_type
|
||||
return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
|
||||
|
||||
|
||||
class FakeQuantPerChannel(PrimitiveWithInfer):
|
||||
|
@ -946,19 +954,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, min_type, max_type):
|
||||
if context.get_context('device_target') == "GPU":
|
||||
valid_dtypes = (mstype.float32,)
|
||||
else:
|
||||
valid_dtypes = (mstype.float16, mstype.float32)
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
|
||||
("x", "min", "max"),
|
||||
(x_type, min_type, max_type)))
|
||||
return x_type
|
||||
return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
|
||||
|
||||
|
||||
class FakeQuantPerChannelGrad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs grad of FakeQuantPerChannelGrad operation.
|
||||
Performs grad of FakeQuantPerChannel operation.
|
||||
|
||||
Examples:
|
||||
>>> fqmmpc_grad = FakeQuantPerChannelGrad()
|
||||
|
@ -1001,14 +1002,7 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
|
|||
return dout_shape
|
||||
|
||||
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
||||
if context.get_context('device_target') == "GPU":
|
||||
valid_dtypes = (mstype.float32,)
|
||||
else:
|
||||
valid_dtypes = (mstype.float16, mstype.float32)
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
|
||||
("dout", "x", "min", "max"),
|
||||
(dout_type, x_type, min_type, max_type)))
|
||||
return dout_type
|
||||
return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
|
||||
|
||||
|
||||
class BatchNormFold(PrimitiveWithInfer):
|
||||
|
@ -1298,7 +1292,7 @@ class BatchNormFold2(PrimitiveWithInfer):
|
|||
|
||||
class BatchNormFold2Grad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs grad of CorrectionAddGrad operation.
|
||||
Performs grad of BatchNormFold2 operation.
|
||||
|
||||
Examples:
|
||||
>>> bnf2_grad = ops.BatchNormFold2Grad()
|
||||
|
@ -1386,7 +1380,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
|
|||
|
||||
|
||||
class BatchNormFoldGradD(PrimitiveWithInfer):
|
||||
"""Performs grad of _BatchNormFoldGrad operation."""
|
||||
"""Performs grad of BatchNormFold operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
|
@ -1460,7 +1454,7 @@ class BatchNormFold2D(PrimitiveWithInfer):
|
|||
|
||||
|
||||
class BatchNormFold2GradD(PrimitiveWithInfer):
|
||||
"""Performs grad of CorrectionAddGrad operation."""
|
||||
"""Performs grad of BatchNormFold2 operation."""
|
||||
channel_axis = 1
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -1678,7 +1672,6 @@ class WtsARQ(PrimitiveWithInfer):
|
|||
The WtsARQ(Weights Adaptive Range Quantization).
|
||||
|
||||
Args:
|
||||
axes (list): Specify channels for ARQ algorithm.
|
||||
num_bits (int): The bits num used for quantize.
|
||||
offset_flag (bool): Whether use offset for quantize.
|
||||
|
||||
|
|
|
@@ -860,7 +860,7 @@ class Gather(Primitive):

@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize Gather"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
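For reference, Gather consumes the three inputs registered above (params, indices, axis). A small usage sketch with illustrative values:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    params = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
    indices = Tensor(np.array([0, 2]), mindspore.int32)
    gather = ops.Gather()
    print(gather(params, indices, 0))
    # expected: [[1. 2.]
    #            [5. 6.]]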
@ -870,11 +870,10 @@ class GatherV2(PrimitiveWithCheck):
|
|||
Please use Gather instead.
|
||||
"""
|
||||
|
||||
|
||||
@deprecated("1.1", "Gather", True)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
"""Initialize GatherV2"""
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||
|
||||
def __check__(self, params, indices, axis):
|
||||
|
@ -920,7 +919,7 @@ class SparseGatherV2(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
"""Initialize SparseGatherV2"""
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||
|
||||
|
||||
|
@@ -940,7 +939,7 @@ class Padding(PrimitiveWithInfer):
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.

Args:
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive.
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.

Inputs:
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
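A quick sketch of the behaviour described above (a trailing dimension of size 1 is padded with zeros up to pad_dim_size); the input values are illustrative:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    x = Tensor(np.array([[8], [10]]), mindspore.float32)   # last dimension must be 1
    pad = ops.Padding(pad_dim_size=4)
    print(pad(x))
    # expected: [[ 8.  0.  0.  0.]
    #            [10.  0.  0.  0.]]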
@ -1649,7 +1648,7 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
if mstype.issubclass_(x['dtype'], mstype.tensor):
|
||||
raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.')
|
||||
for shp in x_shp:
|
||||
if shp != []:
|
||||
if shp:
|
||||
x_rank = len(np.array(x_value, np.int64).shape)
|
||||
raise ValueError(f'For \'{self.name}\' the rank of input must be 1, but got {x_rank}.')
|
||||
for i, value in enumerate(x_value):
|
||||
|
@ -2086,7 +2085,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
|
||||
validator.check(f'rank of input_x', len(x_shp),
|
||||
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
|
||||
if (not -1 in x_shp and not -1 in segment_ids_shp):
|
||||
if -1 not in x_shp and -1 not in segment_ids_shp:
|
||||
# only validate when both shapes fully known
|
||||
for i, value in enumerate(segment_ids_shp):
|
||||
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
||||
|
@ -2176,7 +2175,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
|
|||
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
|
||||
if (not -1 in x_shape and not -1 in segment_ids_shape):
|
||||
if -1 not in x_shape and -1 not in segment_ids_shape:
|
||||
# only validate when both shapes fully known
|
||||
validator.check(f'first shape of input_x', x_shape[0],
|
||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||
|
@ -2236,7 +2235,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
|
|||
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
|
||||
if (not -1 in x_shape and not -1 in segment_ids_shape):
|
||||
if -1 not in x_shape and -1 not in segment_ids_shape:
|
||||
# only validate when both shapes fully known
|
||||
validator.check(f'first shape of input_x', x_shape[0],
|
||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||
|
@ -2701,20 +2700,20 @@ class Slice(PrimitiveWithInfer):
|
|||
>>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
|
||||
... [[3, 3, 3], [4, 4, 4]],
|
||||
... [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
|
||||
>>> slice = ops.Slice()
|
||||
>>> output = slice(data, (1, 0, 0), (1, 1, 3))
|
||||
>>> slice_op = ops.Slice()
|
||||
>>> output = slice_op(data, (1, 0, 0), (1, 1, 3))
|
||||
>>> print(output)
|
||||
[[[3 3 3]]]
|
||||
>>> output = slice(data, (1, 0, 0), (1, 1, 2))
|
||||
>>> output = slice_op(data, (1, 0, 0), (1, 1, 2))
|
||||
>>> print(output)
|
||||
[[[3 3]]]
|
||||
>>> output = slice(data, (1, 0, 0), (1, 1, 1))
|
||||
>>> output = slice_op(data, (1, 0, 0), (1, 1, 1))
|
||||
>>> print(output)
|
||||
[[[3]]]
|
||||
>>> output = slice(data, (1, 1, 0), (1, 1, 3))
|
||||
>>> output = slice_op(data, (1, 1, 0), (1, 1, 3))
|
||||
>>> print(output)
|
||||
[[[4 4 4]]]
|
||||
>>> output = slice(data, (1, 0, 1), (1, 1, 2))
|
||||
>>> output = slice_op(data, (1, 0, 1), (1, 1, 2))
|
||||
>>> print(output)
|
||||
[[[3 3]]]
|
||||
"""
|
||||
|
@ -2786,6 +2785,7 @@ class ReverseV2(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis):
|
||||
"""Initialize ReverseV2."""
|
||||
validator.check_value_type('axis', axis, [list, tuple], self.name)
|
||||
for i, each in enumerate(axis):
|
||||
validator.check_value_type(f'axis[{i}]', each, [int], self.name)
|
||||
|
@ -2847,6 +2847,7 @@ class Rint(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Rint."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -2916,7 +2917,7 @@ class Select(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize Select."""
|
||||
self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, cond_shape, x_shape, y_shape):
|
||||
|
@ -3099,9 +3100,9 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
>>> # [6,6,6]
|
||||
>>> # ]
|
||||
>>> # ]
|
||||
>>> slice = ops.StridedSlice()
|
||||
>>> output = slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
|
||||
>>> # Take the call of operator " output = slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
|
||||
>>> strided_slice = ops.StridedSlice()
|
||||
>>> output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
|
||||
>>> # Take this " output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
|
||||
>>> # start = [1, 0, 2] , end = [3, 1, 3], stride = [1, 1, 1], Find a segment of (start, end),
|
||||
>>> # note that end is an open interval
|
||||
>>> # To facilitate understanding, this operator can be divided into three steps:
|
||||
|
@ -3144,7 +3145,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
>>> # The final output after finishing is:
|
||||
[[[3], [5]]]
|
||||
>>> # anothor example like :
|
||||
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
|
||||
>>> output = strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
|
||||
>>> print(output)
|
||||
[[[3. 3. 3.]]]
|
||||
"""
|
||||
|
@ -5300,7 +5301,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, indexing="xy"):
|
||||
"""Init Meshgrid"""
|
||||
"""Initialize Meshgrid."""
|
||||
validator.check_value_type("indexing", indexing, (str), self.name)
|
||||
if indexing not in ("xy", "ij"):
|
||||
raise ValueError("indexing parameter must be either 'xy' or 'ij'")
|
||||
|
@ -5598,6 +5599,7 @@ class TransShape(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize TransShape."""
|
||||
self.__setattr_flag__ = True
|
||||
|
||||
def __infer__(self, x, shape):
|
||||
|
@ -5704,7 +5706,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
"""Initialize EmbeddingLookup."""
|
||||
self.__setattr_flag__ = True
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
|
||||
outputs=['output'])
|
||||
|
|
|
@ -109,6 +109,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
"""Initialize AllReduce."""
|
||||
if not isinstance(op, type(ReduceOp.SUM)):
|
||||
raise TypeError("The operation of AllReduce should be str.")
|
||||
if not isinstance(_get_group(group), str):
|
||||
|
@ -182,6 +183,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
"""Initialize AllGather."""
|
||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.rank = get_rank(_get_group(group))
|
||||
self.rank_size = get_group_size(_get_group(group))
|
||||
|
@ -215,6 +217,7 @@ class _MiniStepAllGather(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
|
||||
"""Initialize _MiniStepAllGather."""
|
||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.rank = get_rank(_get_group(group))
|
||||
self.rank_size = get_group_size(_get_group(group))
|
||||
|
@ -247,7 +250,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|||
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
|
||||
|
||||
Args:
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
|
||||
|
||||
Raises:
|
||||
TypeError: If group is not a list nor tuple, or elements of group are not int.
|
||||
|
@ -263,6 +266,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=None):
|
||||
"""Initialize _HostAllGather."""
|
||||
if group is None:
|
||||
raise ValueError(f"For '{self.name}' group must be set.")
|
||||
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||
|
@ -338,6 +342,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
"""Initialize ReduceScatter."""
|
||||
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
|
||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.op = op
|
||||
|
@ -347,9 +352,11 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
self.add_prim_attr('fusion', 0)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if self.rank_size == 0:
|
||||
raise ValueError(f"For '{self.name}' rank_size can not be zero.")
|
||||
if x_shape[0] % self.rank_size != 0:
|
||||
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
|
||||
x_shape[0] = int(x_shape[0]/self.rank_size)
|
||||
x_shape[0] = int(x_shape[0] / self.rank_size)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
@ -373,7 +380,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
Args:
|
||||
op (str): Specifies an operation used for element-wise reductions,
|
||||
like sum, max, avg. Default: ReduceOp.SUM.
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
|
||||
|
||||
Raises:
|
||||
TypeError: If op is not a string and group is not a list nor tuple,
|
||||
|
@ -383,6 +390,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, op=ReduceOp.SUM, group=None):
|
||||
"""Initialize _HostReduceScatter."""
|
||||
if group is None:
|
||||
raise ValueError(f"For '{self.name}' group must be set.")
|
||||
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
|
||||
|
@ -398,7 +406,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
if x_shape[0] % self.group_size != 0:
|
||||
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
|
||||
x_shape[0] = int(x_shape[0]/self.group_size)
|
||||
x_shape[0] = int(x_shape[0] / self.group_size)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
@ -465,6 +473,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
"""Initialize Broadcast."""
|
||||
validator.check_value_type('root_rank', root_rank, (int,), self.name)
|
||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
check_hcom_group_valid(group)
|
||||
|
@ -592,6 +601,7 @@ class _MirrorOperator(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=None, dev_num=None, mean_flag=None):
|
||||
"""Initialize _MirrorOperator."""
|
||||
self.group = group
|
||||
self.dev_num = dev_num
|
||||
self.mean_flag = mean_flag
|
||||
|
@ -621,6 +631,7 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
|
||||
"""Initialize _MirrorMiniStepOperator."""
|
||||
self.group = group
|
||||
self.dev_num = dev_num
|
||||
self.mean_flag = mean_flag
|
||||
|
@ -645,6 +656,7 @@ class _VirtualDiv(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, divisor=None):
|
||||
"""Initialize _VirtualDiv."""
|
||||
self.divisor = divisor
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -661,7 +673,7 @@ class _VirtualAdd(PrimitiveWithInfer):
|
|||
"""Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize _VirtualAdd."""
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return x_shape
|
||||
|
@ -679,7 +691,7 @@ class _VirtualDataset(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize _VirtualDataset."""
|
||||
|
||||
def infer_shape(self, *args):
|
||||
return args
|
||||
|
@ -696,12 +708,10 @@ class _VirtualAssignAdd(PrimitiveWithInfer):
|
|||
Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
|
||||
internal use of parallel modules and cannot be called by users.
|
||||
|
||||
Args:
|
||||
micro (int): MicroBatch. Default: 0.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize _VirtualAssignAdd."""
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return x_shape
|
||||
|
@ -720,7 +730,7 @@ class _VirtualAccuGrad(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize _VirtualAccuGrad."""
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return x_shape
|
||||
|
@ -745,6 +755,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=None, dev_num=None, mean_flag=None):
|
||||
"""Initialize _MirrorMicroStepOperator."""
|
||||
self.group = group
|
||||
self.dev_num = dev_num
|
||||
self.mean_flag = mean_flag
|
||||
|
@ -765,7 +776,7 @@ class _VirtualOutput(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize _VirtualOutput."""
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
@ -784,7 +795,7 @@ class _GetTensorSlice(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ChunkTensor"""
|
||||
"""Initialize _GetTensorSlice."""
|
||||
|
||||
def infer_value(self, x, dev_mat, tensor_map):
|
||||
from mindspore.parallel._tensor import _load_tensor
|
||||
|
|
|
@ -68,20 +68,20 @@ class GeSwitch(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize GeSwitch."""
|
||||
|
||||
def __call__(self, data, pred):
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_shape(self, data, pred):
|
||||
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
|
||||
return (data, data)
|
||||
return data, data
|
||||
|
||||
def infer_dtype(self, data_type, pred_type):
|
||||
validator.check_subclass(
|
||||
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
||||
validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
|
||||
return (data_type, data_type)
|
||||
return data_type, data_type
|
||||
|
||||
|
||||
class Merge(PrimitiveWithInfer):
|
||||
|
@ -108,13 +108,13 @@ class Merge(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize Merge."""
|
||||
|
||||
def __call__(self, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
return (inputs[0], [1])
|
||||
return inputs[0], [1]
|
||||
|
||||
def infer_dtype(self, inputs):
|
||||
args = {}
|
||||
|
@ -122,4 +122,4 @@ class Merge(PrimitiveWithInfer):
|
|||
args['inputs[%d]' % i] = item
|
||||
|
||||
validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return (inputs[0], mstype.int32)
|
||||
return inputs[0], mstype.int32
|
||||
|
|
|
@ -86,7 +86,7 @@ class ScalarSummary(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize ScalarSummary."""
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
|
||||
|
@ -125,7 +125,7 @@ class ImageSummary(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize ImageSummary."""
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
def __infer__(self, name, value):
|
||||
|
@ -177,7 +177,7 @@ class TensorSummary(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize TensorSummary."""
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
|
||||
|
@ -217,7 +217,7 @@ class HistogramSummary(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init"""
|
||||
"""Initialize HistogramSummary."""
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
def __infer__(self, name, value):
|
||||
|
@ -285,6 +285,7 @@ class InsertGradientOf(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, f):
|
||||
"""Initialize InsertGradientOf."""
|
||||
self.add_prim_attr('side_effect_backprop', True)
|
||||
self.f = f
|
||||
|
||||
|
@ -337,6 +338,7 @@ class HookBackward(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
def __init__(self, hook_fn, cell_id=""):
|
||||
"""Initialize HookBackward."""
|
||||
super(HookBackward, self).__init__(self.__class__.__name__)
|
||||
self.add_prim_attr("cell_id", cell_id)
|
||||
self.init_attrs["cell_id"] = cell_id
|
||||
|
@ -398,6 +400,7 @@ class Print(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Print."""
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
def __call__(self, *args):
|
||||
|
|
|
@ -173,6 +173,7 @@ class TensorAdd(_MathBinaryOp):
|
|||
@deprecated("1.1", "Add", True)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize TensorAdd."""
|
||||
_MathBinaryOp.__init__(self)
|
||||
|
||||
def infer_value(self, x, y):
|
||||
|
@ -310,7 +311,7 @@ class _Reduce(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
|
||||
If false, don't keep these dimensions.
|
||||
If false, don't keep these dimensions. Default: False.
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
|
@ -603,7 +604,7 @@ class ReduceMax(_Reduce):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_dims=False):
|
||||
"""ReduceMax"""
|
||||
"""Initialize ReduceMax."""
|
||||
super(ReduceMax, self).__init__(keep_dims)
|
||||
self.__setattr_flag__ = True
|
||||
|
||||
|
@ -745,6 +746,7 @@ class CumProd(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, exclusive=False, reverse=False):
|
||||
"""Initialize CumProd."""
|
||||
cls_name = self.name
|
||||
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
|
||||
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
|
||||
|
@ -803,6 +805,7 @@ class MatMul(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
"""Initialize MatMul."""
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
|
@ -908,6 +911,7 @@ class BatchMatMul(MatMul):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
"""Initialize BatchMatMul."""
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
|
@ -1016,13 +1020,14 @@ class AddN(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize AddN."""
|
||||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def check_elim(self, inputs):
|
||||
if len(inputs) != 1:
|
||||
return (False, None)
|
||||
return False, None
|
||||
if isinstance(inputs[0], Tensor):
|
||||
return (True, inputs[0])
|
||||
return True, inputs[0]
|
||||
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
|
||||
|
||||
|
||||
|
@ -1068,14 +1073,15 @@ class AccumulateNV2(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize AccumulateNV2."""
|
||||
self.__setattr_flag__ = True
|
||||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def check_elim(self, inputs):
|
||||
if len(inputs) != 1:
|
||||
return (False, None)
|
||||
return False, None
|
||||
if isinstance(inputs[0], Tensor):
|
||||
return (True, inputs[0])
|
||||
return True, inputs[0]
|
||||
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
|
@ -1732,7 +1738,7 @@ class Expm1(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Exp"""
|
||||
"""Initialize Expm1."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -1770,15 +1776,16 @@ class HistogramFixedWidth(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16)
|
||||
>>> range = Tensor([0.0, 5.0], mindspore.float16)
|
||||
>>> range_op = Tensor([0.0, 5.0], mindspore.float16)
|
||||
>>> hist = ops.HistogramFixedWidth(5)
|
||||
>>> output = hist(x, range)
|
||||
>>> output = hist(x, range_op)
|
||||
>>> print(output)
|
||||
[2 1 1 0 2]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, nbins, dtype='int32'):
|
||||
"""Initialize HistogramFixedWidth."""
|
||||
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
|
||||
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
|
||||
valid_values = ['int32', 'int64']
|
||||
|
@ -1825,6 +1832,7 @@ class Log(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Log."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x):
|
||||
|
@ -1870,6 +1878,7 @@ class Log1p(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Log1p."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
|
@ -2437,6 +2446,7 @@ class Floor(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Floor."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -2515,6 +2525,7 @@ class Ceil(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Ceil."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -3782,11 +3793,11 @@ class NMSWithMask(PrimitiveWithInfer):
|
|||
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
|
||||
validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name)
|
||||
num = bboxes_shape[0]
|
||||
return (bboxes_shape, (num,), (num,))
|
||||
return bboxes_shape, (num,), (num,)
|
||||
|
||||
def infer_dtype(self, bboxes_dtype):
|
||||
validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
|
||||
return (bboxes_dtype, mstype.int32, mstype.bool_)
|
||||
return bboxes_dtype, mstype.int32, mstype.bool_
|
||||
|
||||
|
||||
class Abs(PrimitiveWithInfer):
|
||||
|
@ -3811,8 +3822,8 @@ class Abs(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([-1.0, 1.0, 0.0]), mindspore.float32)
|
||||
>>> abs = ops.Abs()
|
||||
>>> output = abs(input_x)
|
||||
>>> abs_op = ops.Abs()
|
||||
>>> output = abs_op(input_x)
|
||||
>>> print(output)
|
||||
[1. 1. 0.]
|
||||
"""
|
||||
|
@ -3896,8 +3907,8 @@ class Round(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0.8, 1.5, 2.3, 2.5, -4.5]), mindspore.float32)
|
||||
>>> round = ops.Round()
|
||||
>>> output = round(input_x)
|
||||
>>> round_op = ops.Round()
|
||||
>>> output = round_op(input_x)
|
||||
>>> print(output)
|
||||
[ 1. 2. 2. 2. -4.]
|
||||
"""
|
||||
|
|
|
@@ -172,6 +172,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):

@prim_attr_register
def __init__(self, output_size):
"""Initialize AdaptiveAvgPool2D."""
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
if isinstance(output_size, tuple):
validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name)

@@ -187,7 +188,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
out_size = [i if i else j for i, j in zipped]
for item in out_size:
validator.check_value_type("item of output_size", item, [int], self.name)
self.add_prim_attr('output_size', (out_size))
self.add_prim_attr('output_size', out_size)
output_shape = x_shape[:len(x_shape) - len(out_size)] + out_size
return output_shape
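The shape inference above lets a partially specified output_size fall back to the input dimensions: falsy (None) entries keep the corresponding input size. A standalone sketch of that step with assumed shapes:

    x_shape = [1, 32, 9, 9]
    output_size = (None, 4)
    zipped = zip(output_size, x_shape[-len(output_size):])
    out_size = [i if i else j for i, j in zipped]
    print(out_size)                                            # [9, 4]
    print(x_shape[:len(x_shape) - len(out_size)] + out_size)   # [1, 32, 9, 4]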
@ -238,6 +239,7 @@ class Softmax(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=-1):
|
||||
"""Initialize Softmax."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type("axis", axis, [int, tuple], self.name)
|
||||
if isinstance(axis, int):
|
||||
|
@ -286,6 +288,7 @@ class LogSoftmax(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=-1):
|
||||
"""Initialize LogSoftmax."""
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
|
||||
|
||||
|
@ -684,6 +687,7 @@ class HSwish(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize HSwish."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, xshape):
|
||||
|
@ -728,6 +732,7 @@ class Sigmoid(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Sigmoid."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, input_x):
|
||||
|
@ -774,6 +779,7 @@ class HSigmoid(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize HSigmoid."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -865,7 +871,6 @@ class InstanceNorm(PrimitiveWithInfer):
|
|||
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
|
||||
(e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
|
||||
Momentum value must be [0, 1]. Default: 0.1.
|
||||
data_format (str): The optional value for data format, is 'NCHW'. Default: "NCHW".
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C)`,
|
||||
|
@ -928,6 +933,7 @@ class InstanceNorm(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, epsilon=1e-5, momentum=0.1):
|
||||
"""Initialize InstanceNorm."""
|
||||
self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
|
||||
outputs=['y', 'save_mean', 'save_variance'])
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
|
@ -945,7 +951,7 @@ class InstanceNorm(PrimitiveWithInfer):
|
|||
validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
|
||||
save_mean_shape = gamma
|
||||
save_mean_shape[0] = save_mean_shape[0] * input_shape_norm[0]
|
||||
return (input_x, save_mean_shape, save_mean_shape)
|
||||
return input_x, save_mean_shape, save_mean_shape
|
||||
|
||||
def infer_dtype(self, input_x, gamma, beta, mean, variance):
|
||||
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
|
||||
|
@ -954,7 +960,7 @@ class InstanceNorm(PrimitiveWithInfer):
|
|||
args_moving = {"mean": mean, "variance": variance}
|
||||
valid_dtypes = [mstype.tensor_type(mstype.float32)]
|
||||
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
|
||||
return (input_x, gamma, gamma)
|
||||
return input_x, gamma, gamma
|
||||
|
||||
|
||||
class BNTrainingReduce(PrimitiveWithInfer):
|
||||
|
@ -992,15 +998,16 @@ class BNTrainingReduce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize BNTrainingReduce."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
|
||||
return ([x_shape[1]], [x_shape[1]])
|
||||
return [x_shape[1]], [x_shape[1]]
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
|
||||
return (x_type, x_type)
|
||||
return x_type, x_type
|
||||
|
||||
|
||||
class BNTrainingUpdate(PrimitiveWithInfer):
|
||||
|
@ -1076,6 +1083,7 @@ class BNTrainingUpdate(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
|
||||
"""Initialize BNTrainingUpdate."""
|
||||
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
|
||||
validator.check_value_type("isRef", isRef, [bool], self.name)
|
||||
|
@ -1098,14 +1106,14 @@ class BNTrainingUpdate(PrimitiveWithInfer):
|
|||
validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
return (x, variance, variance, variance, variance)
|
||||
return x, variance, variance, variance, variance
|
||||
|
||||
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid,
|
||||
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
|
||||
("x", "sum", "square_sum", "scale", "b", "mean", "variance"),
|
||||
(x, sum, square_sum, scale, b, mean, variance)))
|
||||
return (x, variance, variance, variance, variance)
|
||||
return x, variance, variance, variance, variance
|
||||
|
||||
|
||||
class BatchNorm(PrimitiveWithInfer):
|
||||
|
@ -1203,6 +1211,7 @@ class BatchNorm(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
|
||||
"""Initialize BatchNorm."""
|
||||
if is_training is False:
|
||||
self.set_signatures(tuple())
|
||||
validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
|
@ -1224,13 +1233,13 @@ class BatchNorm(PrimitiveWithInfer):
|
|||
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
|
||||
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
|
||||
return (input_x, scale, scale, scale, scale)
|
||||
return input_x, scale, scale, scale, scale
|
||||
|
||||
def infer_dtype(self, input_x, scale, bias, mean, variance):
|
||||
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
|
||||
args = {"scale": scale, "bias": bias, "mean": mean, "variance": variance}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||
return (input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32)
|
||||
return input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32
|
||||
|
||||
|
||||
class Conv2D(Primitive):
|
||||
|
@ -1524,6 +1533,7 @@ class _Pool(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize _Pool."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
|
@ -1642,6 +1652,7 @@ class MaxPool(_Pool):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize MaxPool."""
|
||||
super(MaxPool, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
|
||||
|
||||
|
@ -1710,6 +1721,7 @@ class MaxPoolWithArgmax(_Pool):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize MaxPoolWithArgmax."""
|
||||
super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -1784,6 +1796,7 @@ class MaxPool3D(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCDHW"):
|
||||
"""Initialize MaxPool3D."""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
|
@ -1900,6 +1913,7 @@ class AvgPool(_Pool):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize AvgPool."""
|
||||
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
|
||||
|
||||
|
@ -2074,6 +2088,7 @@ class Conv2DTranspose(Conv2DBackpropInput):
|
|||
@prim_attr_register
|
||||
def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0,
|
||||
pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"):
|
||||
"""Initialize Conv2DTranspose."""
|
||||
super(Conv2DTranspose, self).__init__(out_channel, kernel_size, pad_mode, pad,
|
||||
pad_list, mode, stride, dilation, group, data_format)
|
||||
|
||||
|
@ -2118,6 +2133,7 @@ class BiasAdd(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, data_format="NCHW"):
|
||||
"""Initialize BiasAdd."""
|
||||
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
|
@ -2177,6 +2193,7 @@ class TopK(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, sorted=False):
|
||||
"""Initialize TopK."""
|
||||
validator.check_value_type("sorted", sorted, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['input', 'k'],
|
||||
outputs=['values', 'indices'])
|
||||
|
@ -2340,12 +2357,12 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
|
|||
validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name)
|
||||
loss_shape = [logits_shape[0]]
|
||||
dlogits_shape = logits_shape
|
||||
return (loss_shape, dlogits_shape)
|
||||
return loss_shape, dlogits_shape
|
||||
|
||||
def infer_dtype(self, logits_type, labels_type):
|
||||
args = {"logits": logits_type, "labels": labels_type}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
|
||||
return (logits_type, logits_type)
|
||||
return logits_type, logits_type
|
||||
|
||||
|
class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):

@@ -2401,6 +2418,7 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, is_grad=False):
+        """Initialize SparseSoftmaxCrossEntropyWithLogits."""
        validator.check_value_type('is_grad', is_grad, [bool], self.name)
        self.init_prim_io_names(inputs=['features', 'labels'], outputs=['output'])
        self.is_grad = is_grad
@@ -2473,6 +2491,7 @@ class ApplyMomentum(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
+        """Initialize ApplyMomentum."""
        self.use_nesterov = validator.check_bool(use_nesterov)
        self.use_locking = validator.check_bool(use_locking)
        validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)

@@ -2543,6 +2562,7 @@ class SmoothL1Loss(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, beta=1.0):
+        """Initialize SmoothL1Loss."""
        validator.check_value_type('beta', beta, [float], self.name)
        validator.check('beta', beta, '', 0, Rel.GT, self.name)
        self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
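A minimal SmoothL1Loss sketch matching the beta default registered above; the values follow the operator's public documentation rather than this commit:

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> loss = ops.SmoothL1Loss(beta=1.0)
    >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
    >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
    >>> output = loss(logits, labels)
    >>> print(output)
    [0.  0.  0.5]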
@@ -2636,6 +2656,7 @@ class DataFormatDimMap(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, src_format='NHWC', dst_format='NCHW'):
+        """Initialize DataFormatDimMap."""
        valid_values = ['NHWC', 'NCHW']
        self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
        self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)

@@ -2692,6 +2713,7 @@ class RNNTLoss(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, blank_label=0):
+        """Initialize RNNTLoss."""
        validator.check_value_type('blank_label', blank_label, [int], self.name)
        self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
                                outputs=['costs', 'grads'])
@@ -2706,7 +2728,7 @@ class RNNTLoss(PrimitiveWithInfer):
        validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
        validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
        costs_shape = (acts_shape[0],)
-        return (costs_shape, acts_shape)
+        return costs_shape, acts_shape

    def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
        validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name)

@@ -2714,7 +2736,7 @@ class RNNTLoss(PrimitiveWithInfer):
                              valid_dtypes=(mstype.int32,), prim_name=self.name),
                      ("labels", "input_length", "label_length"),
                      (labels_type, input_length_type, label_length_type)))
-        return (acts_type, acts_type)
+        return acts_type, acts_type
class SGD(PrimitiveWithCheck):
|
||||
|
@ -2771,6 +2793,7 @@ class SGD(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
|
||||
"""Initialize SGD."""
|
||||
validator.check_value_type("nesterov", nesterov, [bool], self.name)
|
||||
if nesterov and dampening != 0:
|
||||
raise ValueError(f"Nesterov need zero dampening!")
|
||||
|
@ -2863,6 +2886,7 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize ApplyRMSProp."""
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
|
||||
'rho', 'momentum', 'epsilon'], outputs=['output'])
|
||||
|
@ -2969,6 +2993,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize ApplyCenteredRMSProp."""
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
@ -3056,6 +3081,7 @@ class LayerNorm(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, begin_norm_axis=1, begin_params_axis=1, epsilon=1e-7):
|
||||
"""Initialize LayerNorm."""
|
||||
validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
|
||||
validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
|
||||
validator.check_value_type('epsilon', epsilon, [float], self.name)
|
||||
|
@ -3102,6 +3128,7 @@ class L2Normalize(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0, epsilon=1e-4):
|
||||
"""Initialize L2Normalize."""
|
||||
axis = [axis] if isinstance(axis, int) else axis
|
||||
validator.check_value_type('axis', axis, [list, tuple], self.name)
|
||||
validator.check_value_type('epsilon', epsilon, [int, float], self.name)
|
||||
|
@ -3163,6 +3190,7 @@ class DropoutGenMask(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, Seed0=0, Seed1=0):
|
||||
"""Initialize DropoutGenMask."""
|
||||
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
|
||||
validator.check_value_type("Seed0", Seed0, [int], self.name)
|
||||
validator.check_value_type("Seed1", Seed1, [int], self.name)
|
||||
|
@ -3264,6 +3292,7 @@ class ResizeBilinear(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, size, align_corners=False):
|
||||
"""Initialize ResizeBilinear."""
|
||||
validator.check_value_type("size", size, [tuple, list], self.name)
|
||||
validator.check_equal_int(len(size), 2, "size len", self.name)
|
||||
for item in size:
|
||||
|
@ -3337,6 +3366,7 @@ class OneHot(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=-1):
|
||||
"""Initialize OneHot."""
|
||||
self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output'])
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
|
||||
|
@@ -3412,7 +3442,7 @@ class FastGelu(PrimitiveWithInfer):
    @deprecated("1.1", "FastGeLU", True)
    @prim_attr_register
    def __init__(self):
-        """init FastGelu"""
+        """Initialize FastGelu."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, input_x):

@@ -3457,7 +3487,7 @@ class FastGeLU(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        """init FastGeLU"""
+        """Initialize FastGeLU."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, input_x):
@ -3509,6 +3539,7 @@ class GetNext(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, types, shapes, output_num, shared_name):
|
||||
"""Initialize GetNext."""
|
||||
validator.check_value_type("types", types, [list, tuple], self.name)
|
||||
validator.check_value_type("shapes", shapes, [list, tuple], self.name)
|
||||
validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
|
||||
|
@ -3674,6 +3705,7 @@ class LSTM(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||
"""Initialize LSTM."""
|
||||
self.input_size = validator.check_positive_int(input_size, "input_size", self.name)
|
||||
self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name)
|
||||
self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name)
|
||||
|
@ -3704,12 +3736,12 @@ class LSTM(PrimitiveWithInfer):
|
|||
# set arbitrary shape for reserved space
|
||||
reserved_shape = (1, 1)
|
||||
state_shape = (1, 1)
|
||||
return (y_shape, h_shape, c_shape, reserved_shape, state_shape)
|
||||
return y_shape, h_shape, c_shape, reserved_shape, state_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
|
||||
args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
|
||||
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
|
||||
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
|
||||
|
||||
|
||||
class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
|
||||
|
@@ -4023,7 +4055,7 @@ class ComputeAccidentalHits(PrimitiveWithCheck):
    the weight is -FLOAT_MAX. FLOAT_MAX indicates the max value in the type of Float

    Args:
-        num_true (int): The number of target classes per training example.
+        num_true (int): The number of target classes per training example. Default: 1.

    Inputs:
        - **true_classes** (Tensor) - The target classes. With data type of int32 or int64
@ -4252,6 +4284,7 @@ class Adam(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False, use_nesterov=False):
|
||||
"""Initialize Adam."""
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -4368,6 +4401,7 @@ class AdamNoUpdateParam(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False, use_nesterov=False):
|
||||
"""Initialize AdamNoUpdateParam."""
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
|
||||
|
||||
|
@ -4501,6 +4535,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False, use_nesterov=False):
|
||||
"""Initialize FusedSparseAdam."""
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
|
||||
|
@ -4649,6 +4684,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False, use_nesterov=False):
|
||||
"""Initialize FusedSparseLazyAdam."""
|
||||
validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
|
||||
|
@ -4762,6 +4798,7 @@ class FusedSparseFtrl(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
||||
"""Initialize FusedSparseFtrl."""
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -4878,6 +4915,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize FusedSparseProximalAdagrad"""
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -4968,6 +5006,7 @@ class KLDivLoss(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
"""Initialize KLDivLoss."""
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
|
@ -5055,6 +5094,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, reduction='mean'):
|
||||
"""Initialize BinaryCrossEntropy."""
|
||||
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
|
||||
|
||||
def infer_shape(self, x_shape, y_shape, weight_shape):
|
||||
|
@ -5442,6 +5482,7 @@ class ApplyAdagrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, update_slots=True):
|
||||
"""Initialize ApplyAdagrad."""
|
||||
validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
@ -5543,6 +5584,7 @@ class ApplyAdagradV2(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, epsilon, update_slots=True):
|
||||
"""Initialize ApplyAdagradV2."""
|
||||
validator.check_value_type("epsilon", epsilon, [float], self.name)
|
||||
validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -5644,6 +5686,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, update_slots=True, use_locking=False):
|
||||
"""Initialize SparseApplyAdagrad."""
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_is_float(lr, "lr", self.name)
|
||||
validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
|
@ -5748,6 +5791,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, epsilon, use_locking=False, update_slots=True):
|
||||
"""Initialize SparseApplyAdagradV2."""
|
||||
self.lr = validator.check_value_type("lr", lr, [float], self.name)
|
||||
self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name)
|
||||
self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name)
|
||||
|
@ -5860,6 +5904,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize ApplyProximalAdagrad."""
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],
|
||||
outputs=['var', 'accum'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -5985,6 +6030,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize SparseApplyProximalAdagrad."""
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
|
||||
outputs=['var', 'accum'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@@ -6095,7 +6141,7 @@ class ApplyAddSign(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        "Initialize ApplyAddSign"
+        """Initialize ApplyAddSign."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape,

@@ -6225,7 +6271,7 @@ class ApplyPowerSign(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        "Initialize ApplyPowerSign"
+        """Initialize ApplyPowerSign."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,

@@ -6320,7 +6366,7 @@ class ApplyGradientDescent(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        "Initialize ApplyGradientDescent"
+        """Initialize ApplyGradientDescent."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, alpha_shape, delta_shape):
@@ -6408,7 +6454,7 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        "Initialize ApplyGradientDescent"
+        """Initialize ApplyProximalGradientDescent."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
@@ -6495,7 +6541,7 @@ class LARSUpdate(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False):
-        """init"""
+        """Initialize LARSUpdate."""
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("hyperpara", hyperpara, [float], self.name)
        validator.check_value_type("use_clip", use_clip, [bool], self.name)
@ -6601,6 +6647,7 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=False):
|
||||
"""Initialize ApplyFtrl."""
|
||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
@ -6706,6 +6753,7 @@ class SparseApplyFtrl(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
||||
"""Initialize SparseApplyFtrl."""
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
|
@ -6819,6 +6867,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False):
|
||||
"""Initialize SparseApplyFtrlV2."""
|
||||
validator.check_value_type("lr", lr, [float], self.name)
|
||||
validator.check_value_type("l1", l1, [float], self.name)
|
||||
validator.check_value_type("l2", l2, [float], self.name)
|
||||
|
@@ -6854,7 +6903,7 @@ class Dropout(PrimitiveWithCheck):

    Args:
        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
-            means dropping out 10% of input units.
+            means dropping out 10% of input units. Default: 0.5.
        Seed0 (int): Seed0 value for random generating. Default: 0.
        Seed1 (int): Seed1 value for random generating. Default: 0.

@@ -6884,6 +6933,7 @@ class Dropout(PrimitiveWithCheck):

    @prim_attr_register
    def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0):
+        """Initialize Dropout."""
        self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name)
        self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
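A small sketch of the primitive-level Dropout registered above; it returns both the scaled output and the generated mask, and the input shape here is an illustrative assumption:

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> dropout = ops.Dropout(keep_prob=0.5)
    >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
    >>> output, mask = dropout(x)
    >>> print(output.shape)
    (2, 2, 3)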
@ -6938,6 +6988,7 @@ class Dropout2D(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_prob=0.5):
|
||||
"""Initialize Dropout2D."""
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
|
||||
|
||||
|
@ -6995,6 +7046,7 @@ class Dropout3D(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_prob=0.5):
|
||||
"""Initialize Dropout3D."""
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
|
||||
|
||||
|
@ -7081,6 +7133,7 @@ class CTCLoss(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=True,
|
||||
ignore_longer_outputs_than_inputs=False):
|
||||
"""Initialize CTCLoss."""
|
||||
self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
|
||||
outputs=["loss", "gradient"])
|
||||
validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name)
|
||||
|
@ -7140,6 +7193,7 @@ class CTCGreedyDecoder(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, merge_repeated=True):
|
||||
"""Initialize CTCGreedyDecoder."""
|
||||
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
|
||||
|
||||
def check_shape(self, inputs_shape, sequence_length_shape):
|
||||
|
@ -7169,6 +7223,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
|
||||
"""Initialize BasicLSTMCell."""
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
|
@ -7195,7 +7250,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
ot_shape = c_shape
|
||||
tanhct_shape = c_shape
|
||||
|
||||
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
|
||||
return ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
|
||||
tuple(map(partial(validator.check_tensor_dtype_valid,
|
||||
|
@ -7204,7 +7259,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|||
(x_dtype, h_dtype, w_dtype)))
|
||||
args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
|
||||
return c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype
|
||||
|
||||
|
||||
class DynamicRNN(PrimitiveWithInfer):
|
||||
|
@ -7315,6 +7370,7 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
activation='tanh',
|
||||
forget_bias=0.0,
|
||||
is_training=True):
|
||||
"""Initialize DynamicRNN."""
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
|
@ -7480,6 +7536,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
gate_order="rzh",
|
||||
reset_after=True,
|
||||
is_training=True):
|
||||
"""Initialize DynamicGRUV2."""
|
||||
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
||||
|
@@ -7613,10 +7670,10 @@ class LRN(PrimitiveWithInfer):
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
-        depth_radius (int): Half-width of the 1-D normalization window with the shape of 0-D.
-        bias (float): An offset (usually positive to avoid dividing by 0).
-        alpha (float): A scale factor, usually positive.
-        beta (float): An exponent.
+        depth_radius (int): Half-width of the 1-D normalization window with the shape of 0-D. Default: 5.
+        bias (float): An offset (usually positive to avoid dividing by 0). Default: 1.0.
+        alpha (float): A scale factor, usually positive. Default: 1.0.
+        beta (float): An exponent. Default: 0.5.
        norm_region (str): Specifies normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".

    Inputs:
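A minimal LRN usage sketch relying on the defaults documented above; the NCHW input shape and random values are assumptions for illustration:

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> lrn = ops.LRN()
    >>> x = Tensor(np.random.rand(1, 2, 4, 4), mindspore.float32)
    >>> output = lrn(x)
    >>> print(output.shape)
    (1, 2, 4, 4)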
@@ -7801,12 +7858,12 @@ class Conv3D(PrimitiveWithInfer):
    :math:`padding` is zero-padding added to both sides of the input.

    Args:
-        out_channels (int): The number of output channel :math:`C_{out}`.
+        out_channel (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers. Specifies the depth, height
            and width of the 3D convolution window. Single int means the value is for the depth, height and the width
            of the kernel. A tuple of 3 ints means the first value is for the depth, height and the other is for the
            width of the kernel.
-        mode (int): Modes for different convolutions. It is currently not used
+        mode (int): Modes for different convolutions. It is currently not used. Default: 1.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are both strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
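A short Conv3D usage sketch consistent with the arguments above; the shapes come from the operator's public documentation, not from this commit:

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> conv3d = ops.Conv3D(out_channel=32, kernel_size=(4, 3, 3))
    >>> x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float16)
    >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float16)
    >>> output = conv3d(x, weight)
    >>> print(output.shape)
    (16, 32, 7, 30, 30)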
|
|
@ -70,6 +70,7 @@ class Assign(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Assign."""
|
||||
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
@ -111,6 +112,7 @@ class InplaceAssign(PrimitiveWithInfer):
|
|||
@deprecated("1.3", "Assign", False)
|
||||
@ prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize InplaceAssign."""
|
||||
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x, y, z):
|
||||
|
@ -137,6 +139,7 @@ class Load(PrimitiveWithCheck):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Load."""
|
||||
self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output'])
|
||||
|
||||
def check_dtype(self, variable):
|
||||
|
@@ -178,8 +181,9 @@ class BoundingBoxEncode(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
-        validator.check_value_type('means', means, (tuple), self.name)
-        validator.check_value_type('stds', stds, (tuple), self.name)
+        """Initialize BoundingBoxEncode."""
+        validator.check_value_type('means', means, tuple, self.name)
+        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):

@@ -241,8 +245,9 @@ class BoundingBoxDecode(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
-        validator.check_value_type('means', means, (tuple), self.name)
-        validator.check_value_type('stds', stds, (tuple), self.name)
+        """Initialize BoundingBoxDecode."""
+        validator.check_value_type('means', means, tuple, self.name)
+        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):
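For reference, a hedged sketch of calling BoundingBoxEncode with the default means and stds shown above; the anchor and ground-truth boxes are made-up values:

    >>> import mindspore
    >>> from mindspore import Tensor, ops
    >>> encode = ops.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
    >>> anchor_box = Tensor([[2, 2, 2, 3], [2, 2, 2, 3]], mindspore.float32)
    >>> groundtruth_box = Tensor([[1, 2, 1, 4], [1, 2, 1, 4]], mindspore.float32)
    >>> deltas = encode(anchor_box, groundtruth_box)
    >>> print(deltas.shape)
    (2, 4)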
@ -313,6 +318,7 @@ class CheckValid(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize CheckValid."""
|
||||
self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, bboxes_shape, metas_shape):
|
||||
|
@ -372,6 +378,7 @@ class IOU(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, mode='iou'):
|
||||
"""Initialize IOU."""
|
||||
if mode not in {'iou', 'iof'}:
|
||||
raise KeyError("Mode only support 'iou' or 'iof'.")
|
||||
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])
|
||||
|
@ -407,6 +414,7 @@ class Partial(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Partial."""
|
||||
self.add_prim_attr('side_effect_propagate', 1)
|
||||
|
||||
def __call__(self, *args):
|
||||
|
@ -473,6 +481,7 @@ class Depend(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Depend."""
|
||||
self.add_prim_attr('side_effect_propagate', 1)
|
||||
|
||||
def __call__(self, value, expr):
|
||||
|
@ -603,6 +612,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_classes, dtype="int32"):
|
||||
"""Initialize ConfusionMatrix."""
|
||||
validator.check_value_type("num_classes", num_classes, [int], self.name)
|
||||
validator.check_value_type("dtype", dtype, [str], self.name)
|
||||
|
||||
|
@ -781,6 +791,7 @@ class identity(Primitive):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize identity."""
|
||||
self.add_prim_attr('side_effect_propagate', 1)
|
||||
|
||||
def __call__(self, x):
|
||||
|
|
|
@ -336,7 +336,7 @@ class UniformInt(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class UniformReal(PrimitiveWithInfer):
|
||||
class UniformReal(StandardNormal):
|
||||
r"""
|
||||
Produces random floating-point values i, uniformly distributed to the interval [0, 1).
|
||||
|
||||
|
@ -367,27 +367,6 @@ class UniformReal(PrimitiveWithInfer):
|
|||
(2, 2)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize UniformReal"""
|
||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
|
||||
def __infer__(self, shape):
|
||||
shape_v = shape["value"]
|
||||
if shape_v is None:
|
||||
raise ValueError(f"For {self.name}, shape must be const.")
|
||||
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||
for i, shape_i in enumerate(shape_v):
|
||||
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||
out = {
|
||||
'shape': shape_v,
|
||||
'dtype': mstype.float32,
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class RandomChoiceWithMask(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -445,11 +424,11 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
|
||||
Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
|
||||
return ([self.count, len(x_shape)], [self.count])
|
||||
return [self.count, len(x_shape)], [self.count]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
|
||||
return (mstype.int32, mstype.bool_)
|
||||
return mstype.int32, mstype.bool_
|
||||
|
||||
|
||||
class RandomCategorical(PrimitiveWithInfer):
|
||||
|
@@ -539,12 +518,12 @@ class Multinomial(PrimitiveWithInfer):
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
-        - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
+        - **x** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
          dimensions.
        - **num_samples** (int32) - number of samples to draw.

    Outputs:
-        Tensor with the same rows as input, each row has num_samples sampled indices.
+        Tensor with the same rows as `x`, each row has num_samples sampled indices.

    Raises:
        TypeError: If neither `seed` nor `seed2` is an int.

@@ -555,16 +534,16 @@ class Multinomial(PrimitiveWithInfer):
        ``GPU``

    Examples:
-        >>> input = Tensor([0., 9., 4., 0.], mstype.float32)
+        >>> x = Tensor([0., 9., 4., 0.], mstype.float32)
        >>> multinomial = ops.Multinomial(seed=10)
-        >>> output = multinomial(input, 2)
+        >>> output = multinomial(x, 2)
        >>> print(output)
        [2 1]
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
-        """init"""
+        """Initialize Multinomial."""
        Validator.check_non_negative_int(seed, "seed", self.name)
        Validator.check_non_negative_int(seed2, "seed2", self.name)
        self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
@ -655,11 +634,11 @@ class UniformCandidateSampler(PrimitiveWithInfer):
|
|||
Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
|
||||
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type,
|
||||
(mstype.int32, mstype.int64), self.name)
|
||||
return (true_classes_type, mstype.float32, mstype.float32)
|
||||
return true_classes_type, mstype.float32, mstype.float32
|
||||
|
||||
def infer_shape(self, true_classes_shape):
|
||||
Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
|
||||
return ([self.num_sampled], true_classes_shape, [self.num_sampled])
|
||||
return [self.num_sampled], true_classes_shape, [self.num_sampled]
|
||||
|
||||
|
||||
class LogUniformCandidateSampler(PrimitiveWithInfer):
|
||||
|
@ -675,7 +654,7 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
|
|||
all sampled classes in a batch are unique. Default: True.
|
||||
range_max (int): The number of possible classes. When `unique` is True,
|
||||
`range_max` must be greater than or equal to `num_sampled`. Default: 5.
|
||||
seed (int): Random seed, must be non-negative.
|
||||
seed (int): Random seed, must be non-negative. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true].
|
||||
|
|
|
@@ -59,7 +59,7 @@ class SparseToDense(PrimitiveWithInfer):

    @prim_attr_register
    def __init__(self):
-        """Initialize index_select"""
+        """Initialize SparseToDense."""
        self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])

    def __infer__(self, indices, values, sparse_shape):
|
|
@ -58,8 +58,9 @@ class BondForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, bond_numbers, atom_numbers):
|
||||
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize BondForce."""
|
||||
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.bond_numbers = bond_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('bond_numbers', self.bond_numbers)
|
||||
|
@ -131,8 +132,9 @@ class BondEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, bond_numbers, atom_numbers):
|
||||
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize BondEnergy."""
|
||||
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.bond_numbers = bond_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('bond_numbers', self.bond_numbers)
|
||||
|
@ -199,8 +201,9 @@ class BondAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, bond_numbers, atom_numbers):
|
||||
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize BondAtomEnergy."""
|
||||
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.bond_numbers = bond_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('bond_numbers', self.bond_numbers)
|
||||
|
@ -266,8 +269,9 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, bond_numbers, atom_numbers):
|
||||
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize BondForceWithAtomEnergy."""
|
||||
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.bond_numbers = bond_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('bond_numbers', self.bond_numbers)
|
||||
|
@ -346,8 +350,9 @@ class BondForceWithAtomVirial(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, bond_numbers, atom_numbers):
|
||||
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize BondForceWithAtomVirial."""
|
||||
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.bond_numbers = bond_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('bond_numbers', self.bond_numbers)
|
||||
|
@ -457,7 +462,8 @@ class DihedralForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dihedral_numbers):
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
|
||||
"""Initialize DihedralForce."""
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
|
||||
self.dihedral_numbers = dihedral_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
|
||||
'gamc', 'gams', 'pn'],
|
||||
|
@ -549,7 +555,8 @@ class DihedralEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dihedral_numbers):
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
|
||||
"""Initialize DihedralEnergy."""
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
|
||||
self.dihedral_numbers = dihedral_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
|
||||
'gamc', 'gams', 'pn'],
|
||||
|
@ -639,7 +646,8 @@ class DihedralAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dihedral_numbers):
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
|
||||
"""Initialize DihedralAtomEnergy."""
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
|
||||
self.dihedral_numbers = dihedral_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
|
||||
'gamc', 'gams', 'pn'],
|
||||
|
@ -729,7 +737,8 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dihedral_numbers):
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
|
||||
"""Initialize DihedralForceWithAtomEnergy."""
|
||||
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
|
||||
self.dihedral_numbers = dihedral_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
|
||||
'gamc', 'gams', 'pn'],
|
||||
|
@ -826,7 +835,8 @@ class AngleForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, angle_numbers):
|
||||
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
|
||||
"""Initialize AngleForce."""
|
||||
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
|
||||
self.angle_numbers = angle_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
|
||||
'angle_theta0'],
|
||||
|
@ -902,7 +912,8 @@ class AngleEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, angle_numbers):
|
||||
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
|
||||
"""Initialize AngleEnergy."""
|
||||
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
|
||||
self.angle_numbers = angle_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
|
||||
'angle_theta0'],
|
||||
|
@ -972,7 +983,8 @@ class AngleAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, angle_numbers):
|
||||
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
|
||||
"""Initialize AngleAtomEnergy."""
|
||||
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
|
||||
self.angle_numbers = angle_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
|
||||
'angle_theta0'],
|
||||
|
@ -1043,7 +1055,8 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, angle_numbers):
|
||||
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
|
||||
"""Initialize AngleForceWithAtomEnergy."""
|
||||
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
|
||||
self.angle_numbers = angle_numbers
|
||||
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
|
||||
'angle_theta0'],
|
||||
|
@ -1126,8 +1139,9 @@ class Dihedral14LJForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14LJForce."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.init_prim_io_names(
|
||||
|
@ -1217,8 +1231,9 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14LJEnergy"""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1312,8 +1327,9 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14LJForceWithDirectCF."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1409,8 +1425,9 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14LJCFForceWithAtomEnergy."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1500,8 +1517,9 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14LJAtomEnergy."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1589,8 +1607,9 @@ class Dihedral14CFEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14CFEnergy."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1668,8 +1687,9 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, nb14_numbers, atom_numbers):
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize Dihedral14CFAtomEnergy."""
|
||||
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.dihedral_14_numbers = nb14_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
|
||||
|
@ -1755,13 +1775,14 @@ class MDIterationLeapFrog(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
|
||||
validator.check_value_type('float4_numbers', float4_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('half_dt', half_dt, (float), self.name)
|
||||
validator.check_value_type('dt', dt, (float), self.name)
|
||||
validator.check_value_type('exp_gamma', exp_gamma, (float), self.name)
|
||||
validator.check_value_type('is_max_velocity', is_max_velocity, (int), self.name)
|
||||
validator.check_value_type('max_velocity', max_velocity, (float), self.name)
|
||||
"""Initialize MDIterationLeapFrog."""
|
||||
validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('half_dt', half_dt, float, self.name)
|
||||
validator.check_value_type('dt', dt, float, self.name)
|
||||
validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
|
||||
validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
|
||||
validator.check_value_type('max_velocity', max_velocity, float, self.name)
|
||||
self.float4_numbers = float4_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.half_dt = half_dt
|
||||
|
@ -1828,14 +1849,15 @@ class PMEReciprocalForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1, box_length_2):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('beta', beta, (float), self.name)
|
||||
validator.check_value_type('fftx', fftx, (int), self.name)
|
||||
validator.check_value_type('ffty', ffty, (int), self.name)
|
||||
validator.check_value_type('fftz', fftz, (int), self.name)
|
||||
validator.check_value_type('box_length_0', box_length_0, (float), self.name)
|
||||
validator.check_value_type('box_length_1', box_length_1, (float), self.name)
|
||||
validator.check_value_type('box_length_2', box_length_2, (float), self.name)
|
||||
"""Initialize PMEReciprocalForce."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('beta', beta, float, self.name)
|
||||
validator.check_value_type('fftx', fftx, int, self.name)
|
||||
validator.check_value_type('ffty', ffty, int, self.name)
|
||||
validator.check_value_type('fftz', fftz, int, self.name)
|
||||
validator.check_value_type('box_length_0', box_length_0, float, self.name)
|
||||
validator.check_value_type('box_length_1', box_length_1, float, self.name)
|
||||
validator.check_value_type('box_length_2', box_length_2, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.beta = beta
|
||||
self.fftx = fftx
|
||||
|
@ -1906,9 +1928,10 @@ class PMEExcludedForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, excluded_numbers, beta):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
||||
validator.check_value_type('beta', beta, (float), self.name)
|
||||
"""Initialize PMEExcludedForce."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
|
||||
validator.check_value_type('beta', beta, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.excluded_numbers = excluded_numbers
|
||||
self.beta = beta
|
||||
|
@ -1999,15 +2022,16 @@ class PMEEnergy(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
|
||||
box_length_2):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
|
||||
validator.check_value_type('beta', beta, (float), self.name)
|
||||
validator.check_value_type('fftx', fftx, (int), self.name)
|
||||
validator.check_value_type('ffty', ffty, (int), self.name)
|
||||
validator.check_value_type('fftz', fftz, (int), self.name)
|
||||
validator.check_value_type('box_length_0', box_length_0, (float), self.name)
|
||||
validator.check_value_type('box_length_1', box_length_1, (float), self.name)
|
||||
validator.check_value_type('box_length_2', box_length_2, (float), self.name)
|
||||
"""Initialize PMEEnergy."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
|
||||
validator.check_value_type('beta', beta, float, self.name)
|
||||
validator.check_value_type('fftx', fftx, int, self.name)
|
||||
validator.check_value_type('ffty', ffty, int, self.name)
|
||||
validator.check_value_type('fftz', fftz, int, self.name)
|
||||
validator.check_value_type('box_length_0', box_length_0, float, self.name)
|
||||
validator.check_value_type('box_length_1', box_length_1, float, self.name)
|
||||
validator.check_value_type('box_length_2', box_length_2, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.excluded_numbers = excluded_numbers
|
||||
self.beta = beta
|
||||
|
@ -2113,8 +2137,9 @@ class LJEnergy(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, cutoff_square):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('cutoff_square', cutoff_square, (float), self.name)
|
||||
"""Initialize LJEnergy."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.cutoff_square = cutoff_square
|
||||
self.init_prim_io_names(
|
||||
|
@ -2199,8 +2224,9 @@ class LJForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, cutoff_square):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('cutoff_square', cutoff_square, (float), self.name)
|
||||
"""Initialize LJForce."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.cutoff_square = cutoff_square
|
||||
self.init_prim_io_names(
|
||||
|
@ -2280,9 +2306,10 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, atom_numbers, cutoff, pme_beta):
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
validator.check_value_type('cutoff', cutoff, (float), self.name)
|
||||
validator.check_value_type('pme_beta', pme_beta, (float), self.name)
|
||||
"""Initialize LJForceWithPMEDirectForce."""
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
validator.check_value_type('cutoff', cutoff, float, self.name)
|
||||
validator.check_value_type('pme_beta', pme_beta, float, self.name)
|
||||
self.atom_numbers = atom_numbers
|
||||
self.cutoff = cutoff
|
||||
self.pme_beta = pme_beta
|
||||
|
@ -2340,8 +2367,9 @@ class GetCenterOfGeometry(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, center_numbers, center_numbers_inverse):
|
||||
validator.check_value_type('center_numbers', center_numbers, (int), self.name)
|
||||
validator.check_value_type('center_numbers_inverse', center_numbers_inverse, (float), self.name)
|
||||
"""Initialize GetCenterOfGeometry."""
|
||||
validator.check_value_type('center_numbers', center_numbers, int, self.name)
|
||||
validator.check_value_type('center_numbers_inverse', center_numbers_inverse, float, self.name)
|
||||
self.center_numbers = center_numbers
|
||||
self.center_numbers_inverse = center_numbers_inverse
|
||||
self.add_prim_attr('center_numbers', self.center_numbers)
|
||||
|
@ -2377,8 +2405,9 @@ class MDTemperature(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, residue_numbers, atom_numbers):
|
||||
validator.check_value_type('residue_numbers', residue_numbers, (int), self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
|
||||
"""Initialize MDTemperature."""
|
||||
validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
|
||||
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
|
||||
self.residue_numbers = residue_numbers
|
||||
self.atom_numbers = atom_numbers
|
||||
self.add_prim_attr('residue_numbers', self.residue_numbers)
|
||||
|
@@ -2423,8 +2452,8 @@ class NeighborListUpdate(PrimitiveWithInfer):
            list first time or not.
        nxy(int32): the total number of grids divided in xy plane.
        excluded_atom_numbers(int32): the total atom numbers in the excluded list.
-        cutoff(float32): the cutoff distance for short-range force calculation.
-        skin(float32): the overflow value of cutoff to maintain a neighbor list.
+        cutoff(float32): the cutoff distance for short-range force calculation. Default: 10.0.
+        skin(float32): the overflow value of cutoff to maintain a neighbor list. Default: 2.0.
        cutoff_square(float32): the square value of cutoff.
        half_skin_square(float32): skin*skin/4, indicates the maximum
            square value of the distance atom allowed to move between two updates.
@ -2432,8 +2461,9 @@ class NeighborListUpdate(PrimitiveWithInfer):
|
|||
radius of the neighbor list for each atom.
|
||||
half_cutoff_with_skin(float32): cutoff_with_skin/2.
|
||||
cutoff_with_skin_square(float32): the square value of cutoff_with_skin.
|
||||
refresh_interval(int32): the number of iteration steps between two updates of neighbor list.
|
||||
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid.
|
||||
refresh_interval(int32): the number of iteration steps between two updates of neighbor list. Default: 20.
|
||||
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid. Default: 64.
|
||||
max_neighbor_numbers(int32): The maximum number of neighbors. Default: 800.
|
||||
|
||||
Inputs:
|
||||
- **atom_numbers_in_grid_bucket** (Tensor, int32) - [G,], the number of atoms in each grid bucket.
|
||||
|
@@ -2470,6 +2500,7 @@ class NeighborListUpdate(PrimitiveWithInfer):
    def __init__(self, grid_numbers, atom_numbers, not_first_time, nxy, excluded_atom_numbers,
                 cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
                 refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800):
        """Initialize NeighborListUpdate."""
        self.grid_numbers = grid_numbers
        self.atom_numbers = atom_numbers
        self.refresh_interval = refresh_interval
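Most of the positional arguments of this __init__ are precomputed combinations of cutoff and skin, as the Args section above describes. A minimal sketch of how a caller might derive them is shown below; the grid and atom counts are illustrative placeholders, cutoff_with_skin = cutoff + skin is an assumption consistent with the naming, and the import path is likewise assumed rather than taken from the diff.

# Hedged sketch, not the canonical SPONGE setup: derive the precomputed
# distance arguments documented above and instantiate the primitive.
from mindspore.ops.operations import NeighborListUpdate  # assumed export path

cutoff, skin = 10.0, 2.0                         # documented defaults
cutoff_square = cutoff * cutoff                  # "the square value of cutoff"
half_skin_square = skin * skin / 4.0             # max squared displacement between updates
cutoff_with_skin = cutoff + skin                 # assumption: cutoff extended by the skin buffer
half_cutoff_with_skin = cutoff_with_skin / 2.0
cutoff_with_skin_square = cutoff_with_skin * cutoff_with_skin

neighbor_list_update = NeighborListUpdate(
    grid_numbers=27, atom_numbers=128, not_first_time=0, nxy=9,   # placeholder counts
    excluded_atom_numbers=0,
    cutoff_square=cutoff_square, half_skin_square=half_skin_square,
    cutoff_with_skin=cutoff_with_skin, half_cutoff_with_skin=half_cutoff_with_skin,
    cutoff_with_skin_square=cutoff_with_skin_square,
    refresh_interval=20, cutoff=cutoff, skin=skin,
    max_atom_in_grid_numbers=64, max_neighbor_numbers=800)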
@@ -2641,13 +2672,14 @@ class MDIterationLeapFrogWithRF(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
        validator.check_value_type('float4_numbers', float4_numbers, (int), self.name)
        validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
        validator.check_value_type('half_dt', half_dt, (float), self.name)
        validator.check_value_type('dt', dt, (float), self.name)
        validator.check_value_type('exp_gamma', exp_gamma, (float), self.name)
        validator.check_value_type('is_max_velocity', is_max_velocity, (int), self.name)
        validator.check_value_type('max_velocity', max_velocity, (float), self.name)
        """Initialize MDIterationLeapFrogWithRF."""
        validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        validator.check_value_type('half_dt', half_dt, float, self.name)
        validator.check_value_type('dt', dt, float, self.name)
        validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
        validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
        validator.check_value_type('max_velocity', max_velocity, float, self.name)
        self.float4_numbers = float4_numbers
        self.atom_numbers = atom_numbers
        self.half_dt = half_dt
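Since all of these arguments are Python scalars, the validator calls above only enforce types; keeping the numbers mutually consistent is up to the caller. The sketch below shows one plausible construction, assuming the conventional leap-frog relation half_dt = dt / 2 and treating exp_gamma as a Langevin-style damping factor exp(-gamma * dt); gamma, the sizes, the velocity cap, and the import path are illustrative assumptions, not taken from the source.

import math

from mindspore.ops.operations import MDIterationLeapFrogWithRF  # assumed export path

# Hedged sketch of plausible arguments for MDIterationLeapFrogWithRF.
dt = 1e-3                            # integration step, placeholder value
half_dt = dt / 2.0                   # assumption: half of the step
gamma = 1.0                          # friction coefficient, placeholder
exp_gamma = math.exp(-gamma * dt)    # assumption: Langevin damping factor

leap_frog_rf = MDIterationLeapFrogWithRF(
    float4_numbers=512, atom_numbers=128,      # placeholder sizes
    half_dt=half_dt, dt=dt, exp_gamma=exp_gamma,
    is_max_velocity=1, max_velocity=10.0)      # presumably clamps speeds at max_velocity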
@@ -2740,6 +2772,7 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, atom_numbers, half_dt, dt, exp_gamma):
        """Initialize MDIterationLeapFrogLiujian."""
        self.atom_numbers = atom_numbers
        self.half_dt = half_dt
        self.dt = dt
@@ -2792,6 +2825,7 @@ class CrdToUintCrd(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, atom_numbers):
        """Initialize CrdToUintCrd."""
        self.atom_numbers = atom_numbers
        self.add_prim_attr('atom_numbers', self.atom_numbers)
        self.init_prim_io_names(
@@ -2827,6 +2861,7 @@ class MDIterationSetupRandState(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, atom_numbers, seed):
        """Initialize MDIterationSetupRandState."""
        self.atom_numbers = atom_numbers
        self.seed = seed
        self.add_prim_attr('atom_numbers', self.atom_numbers)
@@ -2873,10 +2908,11 @@ class TransferCrd(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, start_serial, end_serial, number, atom_numbers):
        validator.check_value_type('start_serial', start_serial, (int), self.name)
        validator.check_value_type('end_serial', end_serial, (int), self.name)
        validator.check_value_type('number', number, (int), self.name)
        validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
        """Initialize TransferCrd."""
        validator.check_value_type('start_serial', start_serial, int, self.name)
        validator.check_value_type('end_serial', end_serial, int, self.name)
        validator.check_value_type('number', number, int, self.name)
        validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
        self.start_serial = start_serial
        self.end_serial = end_serial
        self.number = number

@@ -29,6 +29,7 @@ from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._utils.utils import generate_shape_index
from mindspore import Tensor, RowTensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
@@ -89,17 +90,6 @@ class MindDataSet(MindData):
        return tuple(lst)


@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
    return perm


@constexpr
def _generate_inverse_index(x_shape, axis):
    x_rank = len(x_shape)
@@ -155,7 +145,7 @@ def get_bprop_sparse_gather_v2(self):
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
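The switch from the local _generate_shape_index to the shared generate_shape_index helper does not change the permutation the bprop builds: the indices' axes are rotated to the front of dout before unsorted_segment_sum accumulates the parameter gradient. A self-contained sketch, reproducing the logic of the removed helper above, shows how the example in the first comment arises (generate_shape_index from mindspore.ops._utils.utils is assumed to behave identically).

# Self-contained sketch reproducing the permutation logic of the removed
# _generate_shape_index above.
def generate_shape_index(out_shape, indices_shape, axis):
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    # move the index axes in front of the remaining output axes
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    return perm_part1 + index[:axis] + index[axis + ind_rank:]

# The first example comment in the bprop: out_shape (3, 2, 3), 1-D indices, axis 1.
assert generate_shape_index((3, 2, 3), (2,), 1) == (1, 0, 2)
# The second comment ("axis 2 -> (1,2,0)") appears to belong to the companion
# _generate_inverse_index helper, whose body is not shown in full here.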