!18487 remove the redundant code, add docstring of operator init and add default valuefor args.

Merge pull request !18487 from wangshuide/wsd_master_warn
This commit is contained in:
i-robot 2021-06-18 09:38:26 +00:00 committed by Gitee
commit f6cda391a9
47 changed files with 575 additions and 483 deletions

View File

@ -85,6 +85,7 @@ class Softmax(Cell):
"""
def __init__(self, axis=-1):
"""Initialize Softmax."""
super(Softmax, self).__init__()
self.softmax = P.Softmax(axis)
@ -135,6 +136,7 @@ class LogSoftmax(Cell):
"""
def __init__(self, axis=-1):
"""Initialize LogSoftmax."""
super(LogSoftmax, self).__init__()
self.log_softmax = P.LogSoftmax(axis)
@ -185,6 +187,7 @@ class ELU(Cell):
"""
def __init__(self, alpha=1.0):
"""Initialize ELU."""
super(ELU, self).__init__()
self.elu = P.Elu(alpha)
@ -229,6 +232,7 @@ class ReLU(Cell):
"""
def __init__(self):
"""Initialize ReLU."""
super(ReLU, self).__init__()
self.relu = P.ReLU()
@ -271,6 +275,7 @@ class ReLU6(Cell):
"""
def __init__(self):
"""Initialize ReLU6."""
super(ReLU6, self).__init__()
self.relu6 = P.ReLU6()
@ -316,6 +321,7 @@ class LeakyReLU(Cell):
"""
def __init__(self, alpha=0.2):
"""Initialize LeakyReLU."""
super(LeakyReLU, self).__init__()
validator.check_value_type('alpha', alpha, [float, int], self.cls_name)
self.greater_equal = P.GreaterEqual()
@ -366,6 +372,7 @@ class Tanh(Cell):
"""
def __init__(self):
"""Initialize Tanh."""
super(Tanh, self).__init__()
self.tanh = P.Tanh()
@ -413,6 +420,7 @@ class GELU(Cell):
"""
def __init__(self):
"""Initialize GELU."""
super(GELU, self).__init__()
self.gelu = P.GeLU()
@ -456,6 +464,7 @@ class FastGelu(Cell):
"""
def __init__(self):
"""Initialize FastGelu."""
super(FastGelu, self).__init__()
self.fast_gelu = P.FastGeLU()
@ -501,6 +510,7 @@ class Sigmoid(Cell):
"""
def __init__(self):
"""Initialize Sigmoid."""
super(Sigmoid, self).__init__()
self.sigmoid = P.Sigmoid()
@ -560,6 +570,7 @@ class PReLU(Cell):
"""
@cell_attr_register(attrs="")
def __init__(self, channel=1, w=0.25):
"""Initialize PReLU."""
super(PReLU, self).__init__()
validator.check_positive_int(channel, 'channel', self.cls_name)
if isinstance(w, (np.float32, float)):
@ -619,6 +630,7 @@ class HSwish(Cell):
"""
def __init__(self):
"""Initialize HSwish."""
super(HSwish, self).__init__()
self.hswish = P.HSwish()
@ -660,6 +672,7 @@ class HSigmoid(Cell):
"""
def __init__(self):
"""Initialize HSigmoid."""
super(HSigmoid, self).__init__()
self.hsigmoid = P.HSigmoid()
@ -701,6 +714,7 @@ class LogSigmoid(Cell):
"""
def __init__(self):
"""Initialize LogSigmoid."""
super(LogSigmoid, self).__init__()
self.mul = P.Mul()
self.exp = P.Exp()

View File

@ -76,6 +76,7 @@ class L1Regularizer(Cell):
"""
def __init__(self, scale):
"""Initialize L1Regularizer."""
super(L1Regularizer, self).__init__()
Validator.check_value_type("scale", scale, [int, float], self.cls_name)
if scale <= 0:
@ -143,6 +144,7 @@ class Dropout(Cell):
"""
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
"""Initialize Dropout."""
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
@ -197,6 +199,7 @@ class Flatten(Cell):
"""
def __init__(self):
"""Initialize Flatten."""
super(Flatten, self).__init__()
def construct(self, x):
@ -269,6 +272,7 @@ class Dense(Cell):
bias_init='zeros',
has_bias=True,
activation=None):
"""Initialize Dense."""
super(Dense, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -276,7 +280,6 @@ class Dense(Cell):
self.reshape = P.Reshape()
self.shape_op = P.Shape()
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
@ -390,6 +393,7 @@ class ClipByNorm(Cell):
"""
def __init__(self, axis=None):
"""Initialize ClipByNorm."""
super(ClipByNorm, self).__init__()
if axis is None:
axis = ()
@ -468,6 +472,7 @@ class Norm(Cell):
"""
def __init__(self, axis=(), keep_dims=False):
"""Initialize Norm."""
super(Norm, self).__init__()
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
self.axis = axis
@ -561,6 +566,7 @@ class OneHot(Cell):
"""
def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
"""Initialize OneHot."""
super(OneHot, self).__init__()
self.onehot = P.OneHot(axis)
self.depth = depth
@ -633,6 +639,7 @@ class Pad(Cell):
"""
def __init__(self, paddings, mode="CONSTANT"):
"""Initialize Pad."""
super(Pad, self).__init__()
self.mode = mode
self.paddings = paddings
@ -722,6 +729,7 @@ class ResizeBilinear(Cell):
"""
def __init__(self):
"""Initialize ResizeBilinear."""
super(ResizeBilinear, self).__init__()
def construct(self, x, size=None, scale_factor=None, align_corners=False):
@ -780,6 +788,7 @@ class Unfold(Cell):
"""
def __init__(self, ksizes, strides, rates, padding="valid"):
"""Initialize Unfold."""
super(Unfold, self).__init__()
def _check_tuple_or_list(arg_name, arg_val, prim_name):
@ -841,6 +850,7 @@ class Tril(Cell):
"""
def __init__(self):
"""Initialize Tril."""
super(Tril, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
@ -888,6 +898,7 @@ class Triu(Cell):
"""
def __init__(self):
"""Initialize Triu."""
super(Triu, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
@ -946,6 +957,7 @@ class MatrixDiag(Cell):
"""
def __init__(self):
"""Initialize MatrixDiag."""
super(MatrixDiag, self).__init__()
self.matrix_diag = inner.MatrixDiag()
self.dtype = P.DType()
@ -990,6 +1002,7 @@ class MatrixDiagPart(Cell):
"""
def __init__(self):
"""Initialize MatrixDiagPart."""
super(MatrixDiagPart, self).__init__()
self.matrix_diag_part = inner.MatrixDiagPart()
self.dtype = P.DType()
@ -1048,6 +1061,7 @@ class MatrixSetDiag(Cell):
"""
def __init__(self):
"""Initialize MatrixSetDiag."""
super(MatrixSetDiag, self).__init__()
self.matrix_set_diag = inner.MatrixSetDiag()
self.dtype = P.DType()

View File

@ -112,6 +112,7 @@ class Conv2dBnAct(Cell):
activation=None,
alpha=0.2,
after_fake=True):
"""Initialize Conv2dBnAct."""
super(Conv2dBnAct, self).__init__()
self.conv = nn.Conv2d(in_channels,
@ -206,6 +207,7 @@ class DenseBnAct(Cell):
activation=None,
alpha=0.2,
after_fake=True):
"""Initialize DenseBnAct."""
super(DenseBnAct, self).__init__()
self.dense = nn.Dense(
in_channels,

View File

@ -72,7 +72,7 @@ def _get_prefix_and_index(cells):
return prefix, index
class _CellListBase():
class _CellListBase:
"""
An interface for base the cell as list.
@ -84,6 +84,7 @@ class _CellListBase():
by iterator or subscript , it will be interpreted as a list of cells.
"""
def __init__(self):
"""Initialize _CellListBase."""
self.__cell_as_list__ = True
@abstractmethod
@ -133,6 +134,7 @@ class SequentialCell(Cell):
[27. 27.]]]]
"""
def __init__(self, *args):
"""Initialize SequentialCell."""
super(SequentialCell, self).__init__()
self._is_dynamic_name = []
if len(args) == 1:
@ -270,6 +272,7 @@ class CellList(_CellListBase, Cell):
>
"""
def __init__(self, *args, **kwargs):
"""Initialize CellList."""
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
_CellListBase.__init__(self)
Cell.__init__(self, auto_prefix)

View File

@ -47,6 +47,7 @@ class _Conv(Cell):
bias_init,
data_format='NCHW',
transposed=False):
"""Initialize _Conv."""
super(_Conv, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -207,8 +208,8 @@ class Conv2d(_Conv):
Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 240, 1024, 640)
"""
@ -227,6 +228,7 @@ class Conv2d(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCHW'):
"""Initialize Conv2d."""
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation
@ -372,8 +374,8 @@ class Conv1d(_Conv):
Examples:
>>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 120, 640]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 240, 640)
"""
@ -391,7 +393,7 @@ class Conv1d(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv1d."""
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)
@ -575,9 +577,9 @@ class Conv3d(_Conv):
``Ascend``
Examples:
>>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
>>> x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
>>> conv3d = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(4, 3, 3))
>>> output = conv3d(input)
>>> output = conv3d(x)
>>> print(output.shape)
(16, 32, 10, 32, 32)
"""
@ -596,6 +598,7 @@ class Conv3d(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCDHW'):
"""Initialize Conv3d."""
kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)
@ -746,10 +749,10 @@ class Conv3dTranspose(_Conv):
ValueError: If `data_format` is not 'NCDHW'.
Examples:
>>> input = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> conv3d_transpose = nn.Conv3dTranspose(in_channels=16, out_channels=3, kernel_size=(4, 6, 2),
... pad_mode='pad')
>>> output = conv3d_transpose(input)
>>> output = conv3d_transpose(x)
>>> print(output.shape)
(32, 3, 13, 37, 33)
"""
@ -768,6 +771,7 @@ class Conv3dTranspose(_Conv):
weight_init='normal',
bias_init='zeros',
data_format='NCDHW'):
"""Initialize Conv3dTranspose."""
kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)
@ -929,8 +933,8 @@ class Conv2dTranspose(_Conv):
Examples:
>>> net = nn.Conv2dTranspose(3, 64, 4, has_bias=False, weight_init='normal', pad_mode='pad')
>>> input = Tensor(np.ones([1, 3, 16, 50]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 3, 16, 50]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 64, 19, 53)
"""
@ -947,6 +951,7 @@ class Conv2dTranspose(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv2dTranspose."""
kernel_size = twice(kernel_size)
stride = twice(stride)
dilation = twice(dilation)
@ -1098,8 +1103,8 @@ class Conv1dTranspose(_Conv):
Examples:
>>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal', pad_mode='pad')
>>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
>>> output = net(input).shape
>>> x = Tensor(np.ones([1, 3, 50]), mindspore.float32)
>>> output = net(x).shape
>>> print(output)
(1, 64, 53)
"""
@ -1116,6 +1121,7 @@ class Conv1dTranspose(_Conv):
has_bias=False,
weight_init='normal',
bias_init='zeros'):
"""Initialize Conv1dTranspose."""
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
Validator.check_value_type("stride", stride, [int], self.cls_name)
Validator.check_value_type("padding", padding, [int], self.cls_name)

View File

@ -98,6 +98,7 @@ class Embedding(Cell):
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
"""Initialize Embedding."""
super(Embedding, self).__init__()
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
@ -223,6 +224,7 @@ class EmbeddingLookup(Cell):
def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True, vocab_cache_size=0):
"""Initialize EmbeddingLookup."""
super(EmbeddingLookup, self).__init__()
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
@ -416,11 +418,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently. Default: None.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
operator (str): The pooling method for the features in one field. Support 'SUM, 'MEAN' and 'MAX'
operator (str): The pooling method for the features in one field. Support 'SUM', 'MEAN' and 'MAX'.
Default: 'SUM'.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@ -464,6 +467,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
"""Initialize MultiFieldEmbeddingLookup."""
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
slice_mode, feature_num_list, max_norm, sparse)
self.field_size = validator.check_positive_int(field_size, 'field_size')

View File

@ -120,10 +120,10 @@ class LSTM(Cell):
Examples:
>>> net = nn.LSTM(10, 16, 2, has_bias=True, batch_first=True, bidirectional=False)
>>> input = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> c0 = Tensor(np.ones([1 * 2, 3, 16]).astype(np.float32))
>>> output, (hn, cn) = net(input, (h0, c0))
>>> output, (hn, cn) = net(x, (h0, c0))
>>> print(output.shape)
(3, 5, 16)
"""
@ -136,6 +136,7 @@ class LSTM(Cell):
batch_first=False,
dropout=0,
bidirectional=False):
"""Initialize LSTM."""
super(LSTM, self).__init__()
validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
@ -360,11 +361,11 @@ class LSTMCell(Cell):
Examples:
>>> net = nn.LSTMCell(10, 12, has_bias=True, batch_first=True, bidirectional=False)
>>> input = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> x = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h = Tensor(np.ones([1, 3, 12]).astype(np.float32))
>>> c = Tensor(np.ones([1, 3, 12]).astype(np.float32))
>>> w = Tensor(np.ones([1152, 1, 1]).astype(np.float32))
>>> output, h, c, _, _ = net(input, h, c, w)
>>> output, h, c, _, _ = net(x, h, c, w)
>>> print(output.shape)
(3, 5, 12)
"""
@ -376,6 +377,7 @@ class LSTMCell(Cell):
batch_first=False,
dropout=0,
bidirectional=False):
"""Initialize LSTMCell."""
super(LSTMCell, self).__init__()
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.transpose = P.Transpose()

View File

@ -99,6 +99,7 @@ class ReduceLogSumExp(Cell):
"""
def __init__(self, axis, keep_dims=False):
"""Initialize ReduceLogSumExp."""
super(ReduceLogSumExp, self).__init__()
validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
@ -129,7 +130,7 @@ class Range(Cell):
start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`. It can not be equal to `start`.
while set the first entry of the range to `0`. It can not be equal to `start`. Default: None.
delta (Union[int, float]): Increment of the range. It can not be equal to zero. Default: 1.
Outputs:
@ -146,6 +147,7 @@ class Range(Cell):
"""
def __init__(self, start, limit=None, delta=1):
"""Initialize Range."""
super(Range, self).__init__()
if delta == 0:
raise ValueError("The input of `delta` can not be equal to zero.")
@ -211,6 +213,7 @@ class LGamma(Cell):
"""
def __init__(self):
"""Initialize LGamma."""
super(LGamma, self).__init__()
# const numbers
self.k_lanczos_gamma = 7
@ -258,7 +261,7 @@ class LGamma(Cell):
for i in range(8):
product_ = k_lanczos_coefficients[i] / (z + i + 1)
reflex_x = product_ + reflex_x
return reflex_x
return reflex_x
reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
t = z + self.lanczos_gamma_plus_one_half
@ -322,6 +325,7 @@ class DiGamma(Cell):
"""
def __init__(self):
"""Initialize DiGamma."""
super(DiGamma, self).__init__()
# const numbers
self.k_lanczos_gamma = 7
@ -360,7 +364,7 @@ class DiGamma(Cell):
for i in range(8):
num = num - k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
denom = denom + k_lanczos_coefficients[i] / (z + i + 1)
return num, denom
return num, denom
num, denom = _calculate_num_denom(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
t = z + self.lanczos_gamma_plus_one_half
@ -586,6 +590,7 @@ class IGamma(Cell):
"""
def __init__(self):
"""Initialize IGamma."""
super(IGamma, self).__init__()
# const numbers
# If more data types are supported, this float max value need to be selected.
@ -674,6 +679,7 @@ class LBeta(Cell):
"""
def __init__(self):
"""Initialize LBeta."""
super(LBeta, self).__init__()
# const numbers
self.log_2pi = np.log(2 * np.pi)
@ -855,6 +861,7 @@ class MatMul(Cell):
@deprecated('1.2', 'ops.matmul', False)
def __init__(self, transpose_x1=False, transpose_x2=False):
"""Initialize MatMul."""
super(MatMul, self).__init__()
validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)
@ -906,9 +913,9 @@ class Moments(Cell):
Calculates the mean and variance of `x`.
Args:
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: ().
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: None.
keep_dims (bool): If true, The dimension of mean and variance are identical with input's.
If false, don't keep these dimensions. Default: False.
If false, don't keep these dimensions. Default: None.
Inputs:
- **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
@ -938,6 +945,7 @@ class Moments(Cell):
"""
def __init__(self, axis=None, keep_dims=None):
"""Initialize Moments."""
super(Moments, self).__init__()
if axis is None:
axis = ()
@ -997,6 +1005,7 @@ class MatInverse(Cell):
[2.1111116 -0.5555557 0.11111111]]
"""
def __init__(self):
"""Initialize MatInverse."""
super(MatInverse, self).__init__()
self.dtype = P.DType()
self.choleskytrsm = P.CholeskyTrsm()
@ -1035,6 +1044,7 @@ class MatDet(Cell):
35.999996
"""
def __init__(self):
"""Initialize MatDet."""
super(MatDet, self).__init__()
self.dtype = P.DType()
self.cholesky = P.Cholesky()

View File

@ -58,6 +58,7 @@ class _BatchNorm(Cell):
process_groups=0,
input_dims='2d',
data_format='NCHW'):
"""Initialize _BatchNorm."""
super(_BatchNorm, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
if num_features < 1:
@ -339,6 +340,7 @@ class BatchNorm1d(_BatchNorm):
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None):
"""Initialize BatchNorm1d."""
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
@ -447,6 +449,7 @@ class BatchNorm2d(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCHW'):
"""Initialize BatchNorm2d."""
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
@ -546,6 +549,7 @@ class BatchNorm3d(Cell):
moving_var_init='ones',
use_batch_statistics=None,
data_format='NCDHW'):
"""Initialize BatchNorm3d."""
super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()
@ -664,6 +668,7 @@ class GlobalBatchNorm(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=2):
"""Initialize GlobalBatchNorm."""
super(GlobalBatchNorm, self).__init__(num_features,
eps,
momentum,
@ -782,6 +787,7 @@ class SyncBatchNorm(_BatchNorm):
moving_var_init='ones',
use_batch_statistics=None,
process_groups=None):
"""Initialize SyncBatchNorm."""
super(SyncBatchNorm, self).__init__(num_features,
eps,
momentum,
@ -861,6 +867,7 @@ class LayerNorm(Cell):
beta_init='zeros',
epsilon=1e-7
):
"""Initialize LayerNorm."""
super(LayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
@ -965,6 +972,7 @@ class InstanceNorm2d(Cell):
affine=True,
gamma_init='ones',
beta_init='zeros'):
"""Initialize InstanceNorm2d."""
super(InstanceNorm2d, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
validator.check_value_type('eps', eps, [float], self.cls_name)
@ -1073,6 +1081,7 @@ class GroupNorm(Cell):
"""
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
"""Initialize GroupNorm."""
super(GroupNorm, self).__init__()
self.num_groups = validator.check_positive_int(num_groups)
self.num_channels = validator.check_positive_int(num_channels)

View File

@ -27,6 +27,7 @@ class _PoolNd(Cell):
"""N-D AvgPool"""
def __init__(self, kernel_size, stride, pad_mode, data_format="NCHW"):
"""Initialize _PoolNd."""
super(_PoolNd, self).__init__()
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
@ -128,6 +129,7 @@ class MaxPool2d(_PoolNd):
"""
def __init__(self, kernel_size=1, stride=1, pad_mode="valid", data_format="NCHW"):
"""Initialize MaxPool2d."""
super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.max_pool = P.MaxPool(kernel_size=self.kernel_size,
strides=self.stride,
@ -196,6 +198,7 @@ class MaxPool1d(_PoolNd):
"""
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
"""Initialize MaxPool1d."""
super(MaxPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
@ -288,6 +291,7 @@ class AvgPool2d(_PoolNd):
stride=1,
pad_mode="valid",
data_format="NCHW"):
"""Initialize AvgPool2d."""
super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.avg_pool = P.AvgPool(kernel_size=self.kernel_size,
strides=self.stride,
@ -359,6 +363,7 @@ class AvgPool1d(_PoolNd):
kernel_size=1,
stride=1,
pad_mode="valid"):
"""Initialize AvgPool1d."""
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)

View File

@ -156,6 +156,7 @@ class _Observer(Cell):
"""
def __init__(self, quant_dtype):
"""Initialize _Observer."""
super(_Observer, self).__init__()
self.quant_dtype = quant_dtype
@ -204,6 +205,7 @@ class UniformQuantObserver(_Observer):
def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
num_channels=1):
"""Initialize UniformQuantObserver."""
super(UniformQuantObserver, self).__init__(quant_dtype)
self.per_channel = per_channel
self.symmetric = symmetric
@ -1109,6 +1111,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
bias_init='zeros',
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize Conv2dBnWithoutFoldQuant."""
super(Conv2dBnWithoutFoldQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -1249,6 +1252,7 @@ class Conv2dQuant(Cell):
bias_init='zeros',
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize Conv2dQuant."""
super(Conv2dQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -1380,6 +1384,7 @@ class DenseQuant(Cell):
activation=None,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize DenseQuant."""
super(DenseQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -1495,6 +1500,7 @@ class ActQuant(_QuantActivation):
fake_before=False,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize ActQuant."""
super(ActQuant, self).__init__()
act_class = activation.__class__
act_list = [nn.ReLU, nn.ReLU6]
@ -1578,6 +1584,7 @@ class TensorAddQuant(Cell):
ema_decay=0.999,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize TensorAddQuant."""
super(TensorAddQuant, self).__init__()
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
@ -1638,6 +1645,7 @@ class MulQuant(Cell):
ema_decay=0.999,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize MulQuant."""
super(MulQuant, self).__init__()
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,

View File

@ -55,19 +55,19 @@ class DenseThor(Cell):
activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
Default: None.
Raises:
ValueError: If weight_init shape or bias_init shape is incorrect.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Raises:
ValueError: If the shape of `weight_init` or `bias_init` is incorrect.
Examples:
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> x = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net = nn.DenseThor(3, 4)
>>> net(input)
>>> net(x)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
"""
@ -78,6 +78,7 @@ class DenseThor(Cell):
bias_init='zeros',
has_bias=True,
activation=None):
"""Initialize DenseThor."""
super(DenseThor, self).__init__()
self.thor = True
self.in_channels = Validator.check_positive_int(in_channels)
@ -117,7 +118,6 @@ class DenseThor(Cell):
self.cube_matmul = P.MatMul(transpose_a=True)
self.getG = P.InsertGradientOf(self.save_gradient)
def _process_ascend_dense_thor(self, out_channels):
"""process ascend dense thor"""
if out_channels == 1001:
@ -139,7 +139,6 @@ class DenseThor(Cell):
self.cast = P.Cast()
self.is_nsp_layer = (out_channels == 2)
def save_gradient(self, dout):
"""
this function only for thor optimizer
@ -201,6 +200,7 @@ class _ConvThor(Cell):
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode,
padding, dilation, group, has_bias, weight_init, bias_init, transposed=False):
"""Initialize _ConvThor."""
super(_ConvThor, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -353,14 +353,15 @@ class Conv2dThor(_ConvThor):
Examples:
>>> net = nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> print(net(input).shape)
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> print(net(x).shape)
(1, 240, 1024, 640)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
pad_mode='same', padding=0, dilation=1, group=1, has_bias=False,
weight_init='normal', bias_init='zeros'):
"""Initialize Conv2dThor."""
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation
@ -448,7 +449,6 @@ class Conv2dThor(_ConvThor):
self.weight_init.shape = weight_shape
self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')
def save_gradient(self, dout):
"""save_gradient"""
out = dout
@ -474,7 +474,6 @@ class Conv2dThor(_ConvThor):
self.matrix_g_cov = matrix_g
return out
def construct(self, x):
if self.thor:
matrix_a = self.img2col(x)
@ -563,6 +562,7 @@ class EmbeddingThor(Cell):
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
"""Initialize EmbeddingThor."""
super(EmbeddingThor, self).__init__()
self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
@ -602,7 +602,6 @@ class EmbeddingThor(Cell):
self.cube_matmul = P.MatMul(transpose_a=True)
self.mul = P.Mul()
def save_gradient(self, dout):
"""
this function only for thor optimizer
@ -634,11 +633,9 @@ class EmbeddingThor(Cell):
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, out_shape)
return output
def extend_repr(self):
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)

View File

@ -85,15 +85,16 @@ class TimeDistributed(Cell):
TypeError: If layer is not a Cell or Primitive.
Examples:
>>> input = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
>>> x = Tensor(np.random.random([32, 10, 3]), mindspore.float32)
>>> dense = nn.Dense(3, 6)
>>> net = nn.TimeDistributed(dense, time_axis=1, reshape_with_axis=0)
>>> output = net(input)
>>> output = net(x)
>>> print(output.shape)
(32, 10, 6)
"""
def __init__(self, layer, time_axis, reshape_with_axis=None):
"""Initialize TimeDistributed."""
if not isinstance(layer, (Cell, Primitive)):
raise TypeError("Please initialize TimeDistributed with mindspore.nn.Cell or "
"mindspore.ops.Primitive instance. You passed: {input}".format(input=layer))

View File

@ -46,6 +46,7 @@ class Loss(Cell):
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self, reduction='mean'):
"""Initialize Loss."""
super(Loss, self).__init__()
if reduction not in ('mean', 'sum', 'none'):
@ -96,6 +97,7 @@ class _Loss(Loss):
Base class for other losses.
"""
def __init__(self, reduction='mean'):
"""Initialize _Loss."""
log.warning("'_Loss' is deprecated from version 1.3 and "
"will be removed in a future version, use 'Loss' instead.")
super(_Loss, self).__init__()
@ -150,6 +152,7 @@ class L1Loss(Loss):
0.33333334
"""
def __init__(self, reduction='mean'):
"""Initialize L1Loss."""
super(L1Loss, self).__init__(reduction)
self.abs = P.Abs()
@ -238,6 +241,7 @@ class RMSELoss(Loss):
0.57735026
"""
def __init__(self):
"""Initialize RMSELoss."""
super(RMSELoss, self).__init__()
self.MSELoss = MSELoss()
@ -285,6 +289,7 @@ class MAELoss(Loss):
0.33333334
"""
def __init__(self, reduction='mean'):
"""Initialize MAELoss."""
super(MAELoss, self).__init__(reduction)
self.abs = P.Abs()
@ -347,6 +352,7 @@ class SmoothL1Loss(Loss):
[0. 0. 0.5]
"""
def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss."""
super(SmoothL1Loss, self).__init__()
self.beta = beta
self.smooth_l1_loss = P.SmoothL1Loss(self.beta)
@ -418,6 +424,7 @@ class SoftmaxCrossEntropyWithLogits(Loss):
def __init__(self,
sparse=False,
reduction='none'):
"""Initialize SoftmaxCrossEntropyWithLogits."""
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = validator.check_bool(sparse, "sparse")
self.reduction = reduction
@ -481,6 +488,7 @@ class DiceLoss(Loss):
0.38596618
"""
def __init__(self, smooth=1e-5):
"""Initialize DiceLoss."""
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.reshape = P.Reshape()
@ -559,6 +567,7 @@ class MultiClassDiceLoss(Loss):
0.3283009
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
"""Initialize MultiClassDiceLoss."""
super(MultiClassDiceLoss, self).__init__()
activation_list = ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'sigmoid']
@ -604,7 +613,7 @@ class SampledSoftmaxLoss(Loss):
Args:
num_sampled (int): The number of classes to randomly sample per batch.
num_classes (int): The number of possible classes.
num_true (int): The number of target classes per training example.
num_true (int): The number of target classes per training example. Default: 1.
sampled_values (Union[list, tuple]): List or tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*CandidateSampler` function.
Default to None, `UniformCandidateSampler` is applied.
@ -650,6 +659,7 @@ class SampledSoftmaxLoss(Loss):
def __init__(self, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=True, seed=0,
reduction='none'):
"""Initialize SampledSoftmaxLoss."""
super(SampledSoftmaxLoss, self).__init__(reduction)
if num_true < 1:
@ -877,6 +887,7 @@ class BCELoss(Loss):
"""
def __init__(self, weight=None, reduction='none'):
"""Initialize BCELoss."""
super(BCELoss, self).__init__()
self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
self.weight_one = weight is None
@ -946,6 +957,7 @@ class CosineEmbeddingLoss(Loss):
0.0003426075
"""
def __init__(self, margin=0.0, reduction="mean"):
"""Initialize CosineEmbeddingLoss."""
super(CosineEmbeddingLoss, self).__init__(reduction)
self.reduce_sum = P.ReduceSum()
self.maximum = P.Maximum()
@ -1035,6 +1047,7 @@ class BCEWithLogitsLoss(Loss):
"""
def __init__(self, reduction='mean', weight=None, pos_weight=None):
"""Initialize BCEWithLogitsLoss."""
super(BCEWithLogitsLoss, self).__init__()
self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction)
if isinstance(weight, Parameter):
@ -1139,6 +1152,7 @@ class FocalLoss(Loss):
"""
def __init__(self, weight=None, gamma=2.0, reduction='mean'):
"""Initialize FocalLoss."""
super(FocalLoss, self).__init__(reduction=reduction)
self.gamma = validator.check_value_type("gamma", gamma, [float])

View File

@ -57,6 +57,7 @@ class SparseToDense(Cell):
"""
def __init__(self):
"""Initialize SparseToDense."""
super(SparseToDense, self).__init__()
self.sparse_to_dense = P.SparseToDense()

View File

@ -29,7 +29,7 @@ from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
from ...common.tensor import RowTensor
from .._utils.utils import range_op, get_1d_shape
from .._utils.utils import range_op, get_1d_shape, generate_shape_index
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@ -358,18 +358,6 @@ def get_bprop_slice(self):
return bprop
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank))
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm
@constexpr
def _generate_inverse_index(x_shape, axis):
x_rank = len(x_shape)
@ -409,7 +397,7 @@ def get_bprop_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
if -1 in shape_op(x):
params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
@ -488,7 +476,7 @@ def get_bprop_sparse_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
@ -680,7 +668,7 @@ def get_bprop_oneslike(self):
@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
"""Generate bprop for OnesLike"""
"""Generate bprop for ZerosLike"""
def bprop(x, out, dout):
return (zeros_like(x),)

View File

@ -26,17 +26,9 @@ get_dtype = P.DType()
@bprops.register("MaximumGrad")
def bprop_maximum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MaximumGrad`."""
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]
return F.zeros_like(x), F.zeros_like(y), dz
@bprops.register("MinimumGrad")
def bprop_minimum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MinimumGrad`."""
def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
"""Backpropagator for primitive `MaximumGrad` and `MinimumGrad`."""
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
dz = out0 * dout[0] + out1 * dout[1]

View File

@ -298,7 +298,6 @@ def get_bprop_floor(self):
bc_x = fill_(dtype_(x), shape_(x), 0.)
return (bc_x,)
return bprop
@ -420,7 +419,7 @@ def get_bprop_xlogy(self):
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
"""Grad definition for `Square` operation."""
"""Grad definition for `SquareSumAll` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
dtype = P.DType()

View File

@ -622,6 +622,7 @@ def get_bprop_tanh_grad(self):
return bprop
@bprop_getters.register(P.Gelu)
@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
"""Grad definition for `GeLU` operation."""
@ -634,18 +635,6 @@ def get_bprop_gelu(self):
return bprop
@bprop_getters.register(P.Gelu)
def get_bprop_gelu_2(self):
"""Grad definition for `GeLU` operation."""
input_grad = G.GeLUGrad()
def bprop(x, out, dout):
dx = input_grad(dout, x, out)
return (dx,)
return bprop
@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
"""Grad definition for `FastGeLU` operation."""
@ -1156,28 +1145,9 @@ def get_bprop_dropout(self):
@bprop_getters.register(P.Dropout2D)
def get_bprop_dropout2d(self):
"""Grad definition for `Dropout2D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()
keep_prob = self.keep_prob
def bprop(x, out, dout):
_, mask = dout
y = cast(mask, mstype.float32)
if keep_prob != 0:
y = y * (1 / keep_prob)
y = mul(x, y)
y = cast(y, dtype(x))
return (y,)
return bprop
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
"""Grad definition for `Dropout3D` operation."""
"""Grad definition for `Dropout2D` and `Dropout3D` operation."""
dtype = P.DType()
cast = P.Cast()
mul = P.Mul()

View File

@ -28,6 +28,7 @@ from .grad_base import bprop_getters
@bprop_getters.register(P.Assign)
def get_bprop_assign(self):
"""Generate bprop for Assign"""
def bprop(x, y, out, dout):
return (dout, zeros_like(y))
return bprop
@ -88,6 +89,7 @@ def get_bprop_sync_batch_norm(self):
@bprop_getters.register(inner.GpuConvertToDynamicShape)
def get_bprop_gpu_convert_to_dynamic_shape(self):
"""Get backprop for GpuConvertToDynamicShape."""
def bprop(x, out, dout):
return (dout,)
return bprop

View File

@ -139,6 +139,7 @@ def get_bprop_BatchNormFold(self):
@bprop_getters.register(P.BNTrainingReduce)
def get_bprop_BNTrainingReduce(self):
"""Generate bprop for BNTrainingReduce for Ascend"""
def bprop(x, out, dout):
return (zeros_like(x),)
@ -199,6 +200,7 @@ def get_bprop_acts_ulq(self):
@bprop_getters.register(Q.WtsARQ)
def get_bprop_wts_arq(self):
"""Grad definition for 'WtsArq' operation"""
def bprop(w, w_min, w_max, out, dout):
return (dout, zeros_like(w_min), zeros_like(w_max))

View File

@ -107,3 +107,15 @@ def get_1d_shape(in_shape):
for i in in_shape:
out_shape *= i
return (out_shape,)
@constexpr
def generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank))
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm

View File

@ -316,6 +316,7 @@ class GradOperation(GradOperation_):
"""
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
"""Initialize GradOperation."""
if not isinstance(get_all, bool):
raise TypeError(f'get_all should be bool, but got {type(get_all)}')
if not isinstance(get_by_list, bool):
@ -425,6 +426,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
"""
def __init__(self, name, read_value=False):
"""Initialize MultitypeFuncGraph."""
MultitypeFuncGraph_.__init__(self, name)
self.entries = list()
if read_value:
@ -523,6 +525,7 @@ class HyperMap(HyperMap_):
"""
def __init__(self, ops=None):
"""Initialize HyperMap."""
self.ops = ops
if ops:
HyperMap_.__init__(self, ops)
@ -586,6 +589,7 @@ class Map(Map_):
"""
def __init__(self, ops=None):
"""Initialize Map."""
self.ops = ops
if ops:
Map_.__init__(self, ops)
@ -610,6 +614,7 @@ class _ListAppend(ListAppend_):
"""
def __init__(self, name):
"""Initialize _ListAppend."""
ListAppend_.__init__(self, name)
def __call__(self, *args):
@ -628,6 +633,7 @@ class _Tail(Tail_):
"""
def __init__(self, name):
"""Initialize _Tail."""
Tail_.__init__(self, name)
def __call__(self, *args):
@ -641,6 +647,7 @@ class _ZipOperation(ZipOperation_):
"""Generates a tuple of zip iterations for inputs."""
def __init__(self, name):
"""Initialize _ZipOperation."""
ZipOperation_.__init__(self, name)
def __call__(self, *args):

View File

@ -115,6 +115,7 @@ class _ClipByGlobalNorm(Cell):
"""
def __init__(self, clip_norm=1.0, use_norm=None):
"""Initialize _ClipByGlobalNorm."""
super(_ClipByGlobalNorm, self).__init__()
# Add interface. This parameter is not used at present
if use_norm is not None:

View File

@ -135,7 +135,7 @@ def _axes_int_check(x1_shape, x2_shape, axes):
raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
if axes == 0:
# outer product, no input validation required
return ([], [])
return [], []
if axes > len(x1_shape) or axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
@ -599,7 +599,7 @@ def _check_matmul_shapes(shape1, shape2):
@constexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
size = [1]*ndim
size = [1] * ndim
for idx, (i, j) in enumerate(zip(shape, out_shape)):
if i != j:
size[idx] = j

View File

@ -450,7 +450,6 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
@ -615,11 +614,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.shape_mul(F.shape(value))
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
const_utils.check_equal(1, size, "When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)

View File

@ -506,7 +506,6 @@ def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_ind
tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)
index_tensor_new_shape, final_shape = [], []
final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
@ -620,14 +619,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
if ellipsis_count > 1:
raise IndexError("An index can have only one ellipsis (...)")
ellipsis_range_size = data_dim - tuple_index_len + 1
begin_strides.extend([0] * (ellipsis_range_size))
begin_strides.extend([0] * ellipsis_range_size)
end_strides.extend(
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
step_strides.extend([1] * (ellipsis_range_size))
step_strides.extend([1] * ellipsis_range_size)
index_count = index_count + ellipsis_range_size
else:
raise IndexError("Not supported index data type, got ",
index, " type is ", type(item))
index, " type is ", type(index))
for index in range(index_count, data_dim):
begin_strides.append(0)
end_strides.append(data_shape[index])

View File

@ -44,6 +44,7 @@ class _TupleAdd(base.TupleAdd_):
"""
def __init__(self, name):
"""Initialize _TupleAdd."""
base.TupleAdd_.__init__(self, name)
def __call__(self, *args):

View File

@ -39,6 +39,7 @@ class _TupleSlice(base.TupleSlice_):
"""
def __init__(self, name):
"""Initialize _TupleSlice."""
base.TupleSlice_.__init__(self, name)
def __call__(self, *args):
@ -61,6 +62,7 @@ class _TupleGetItemTensor(base.TupleGetItemTensor_):
"""
def __init__(self, name):
"""Initialize _TupleGetItemTensor."""
base.TupleGetItemTensor_.__init__(self, name)
def __call__(self, *args):

View File

@ -51,4 +51,4 @@ def _greater_equal_tensor(x, y):
Returns:
Tensor, return value by operator P.GreaterEqual.
"""
return F.tensor_ge(x, y)
return F.tensor_ge(x, y)

View File

@ -51,4 +51,4 @@ def _less_equal_tensor(x, y):
Returns:
Tensor, return value by operator P.LessEqual.
"""
return F.tensor_le(x, y)
return F.tensor_le(x, y)

View File

@ -49,4 +49,4 @@ def _logical_and_tensor(x, y):
Returns:
Tensor, Return logical and operation result of x and y.
"""
return F.logical_and(x, y)
return F.logical_and(x, y)

View File

@ -68,10 +68,6 @@ class RegOp:
Args:
op_name (str): Name of op.
inputs (list): Inputs information of the op.
outputs (list): Outputs information of the op.
attr_ (list): Attribute information of the op.
dtype_format_ (list): Dtype and format information of the op.
"""
def __init__(self, op_name=""):
@ -343,64 +339,13 @@ class AkgAscendRegOp(AkgRegOp):
super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
class AiCPURegOp(RegOp):
class AiCPURegOp(CpuRegOp):
"""Class for AiCPU op info register"""
def __init__(self, op_name):
super(AiCPURegOp, self).__init__(op_name)
self.imply_type = "AiCPU"
def input(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op input information.
Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
param_type (str): Param type of the input. Default: None.
kwargs (dict): Other information of the input.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.inputs.append(input_dict)
return self
def output(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op output information.
Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
param_type (str): Param type of the output. Default: None.
kwargs (dict): Other information of the output.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.outputs.append(output_dict)
return self
def attr(self, name=None, value_type=None, value=None, **kwargs):
"""
Register AiCPU op attribute information.
Args:
name (str): Name of the attribute. Default: None.
value_type (str): Value type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
kwargs (dict): Other information of the attribute.
"""
param_list = [name, value_type, value]
key_list = ["name", "type", "value"]
fn_list = [self._is_string]
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.attr_.append(attr_dict)
return self
class TBERegOp(RegOp):
"""Class for TBE operator information register."""
@ -419,7 +364,7 @@ class TBERegOp(RegOp):
self.is_dynamic_format_ = False
self.op_pattern_ = ""
def async_flag(self, async_flag):
def async_flag(self, async_flag=False):
"""
Define the calculation efficiency of the operator, whether the asynchronous calculation is supported.
@ -441,7 +386,7 @@ class TBERegOp(RegOp):
self.binfile_name_ = binfile_name
return self
def compute_cost(self, compute_cost):
def compute_cost(self, compute_cost=10):
"""
Define the calculation efficiency of operator, which refers to the value of the cost model
in the tiling module.
@ -464,7 +409,7 @@ class TBERegOp(RegOp):
self.kernel_name_ = kernel_name
return self
def partial_flag(self, partial_flag):
def partial_flag(self, partial_flag=True):
"""
Define the calculation efficiency of operator, whether the partial calculation is supported.
@ -486,7 +431,7 @@ class TBERegOp(RegOp):
self.reshape_type_ = reshape_type
return self
def dynamic_shape(self, dynamic_shape):
def dynamic_shape(self, dynamic_shape=False):
"""
Whether the operator supports dynamic shape.
@ -497,7 +442,7 @@ class TBERegOp(RegOp):
self.dynamic_shape_ = dynamic_shape
return self
def need_check_supported(self, need_check_supported):
def need_check_supported(self, need_check_supported=False):
"""
Whether the operator need check supports.
@ -508,7 +453,7 @@ class TBERegOp(RegOp):
self.need_check_supported_ = need_check_supported
return self
def is_dynamic_format(self, is_dynamic_format):
def is_dynamic_format(self, is_dynamic_format=False):
"""
Whether the operator need calop_select_format api.

View File

@ -154,22 +154,9 @@ class RsqrtGrad(PrimitiveWithInfer):
return x_dtype
class SoftmaxGrad(PrimitiveWithInfer):
class SoftmaxGrad(ReciprocalGrad):
"""Performs grad of Softmax operation."""
@prim_attr_register
def __init__(self):
"""Initialize SoftmaxGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class SqrtGrad(PrimitiveWithInfer):
"""Performs grad of Sqrt operation."""
@ -347,7 +334,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
(32, 32, 4, 6, 2)
"""
@prim_attr_register
def __init__(self,
out_channel,
@ -644,7 +630,7 @@ class DropoutGrad(PrimitiveWithInfer):
Args:
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.
means dropping out 10% of input units. Default: 0.5.
Inputs:
- **shape** (tuple[int]) - The shape of target mask.
@ -1030,6 +1016,9 @@ class MaximumGrad(Primitive):
def __init__(self, grad_x=True, grad_y=True):
"""Initialize MaximumGrad"""
def __call__(self, x, y, dout):
raise NotImplementedError
class MaxPoolGradWithArgmax(_PoolGrad):
"""Computes the gradients of MaxPoolWithArgmax."""
@ -1688,6 +1677,7 @@ class ROIAlignGrad(PrimitiveWithInfer):
ROIAlignGrad operator.
Args:
xdiff_shape (tuple): The diff shape.
pooled_height (int): The output feature height.
pooled_width (int): The output feature width.
spatial_scale (float): The feature stride.
@ -1892,7 +1882,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
class SoftplusGrad(PrimitiveWithInfer):
"""Computes gradient for the Log Softmax activation."""
"""Computes gradient for the Softplus activation."""
@prim_attr_register
def __init__(self):

View File

@ -78,13 +78,14 @@ class ExtractImagePatches(PrimitiveWithInfer):
def infer_shape(self, input_x):
"""infer shape"""
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
in_batch, in_depth, in_row, in_col = input_x
_, _, ksize_row, ksize_col = self.ksizes
_, _, stride_row, stride_col = self.strides
_, _, rate_row, rate_col = self.rates
if len(input_x) != 4:
raise ValueError("The `input_x` should be a 4-D tensor, "
f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
out_batch = in_batch
out_depth = ksize_row * ksize_col * in_depth
@ -124,7 +125,7 @@ class Range(PrimitiveWithInfer):
start (float): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`. It can not be equal to `start`.
while set the first entry of the range to `0`. It can not be equal to `start`. Default: None.
delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.
Inputs:
@ -134,9 +135,9 @@ class Range(PrimitiveWithInfer):
Tensor, has the same shape and dtype as `input_x`.
Examples:
>>> range = ops.Range(1.0, 8.0, 2.0)
>>> range_op = ops.Range(1.0, 8.0, 2.0)
>>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
>>> output = range(x)
>>> output = range_op(x)
>>> print(output)
[3, 5, 7, 5]
"""
@ -906,7 +907,7 @@ class StackInit(PrimitiveWithInfer):
at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.
Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.
Supported Platforms:
``Ascend``
@ -940,7 +941,7 @@ class StackPush(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.
Inputs:
- **input** (Tensor) - A tensor to be pushed onto the stack.
@ -966,9 +967,9 @@ class StackPop(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
shape (tuple): The shape of the tensor at the top of the stack.
dtype (mindspore.dtype): The type of the tensor at the top of the stack.
index (int): The index of the stack. Default: 1.
shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.
Outputs:
- **output** (Tensor) - The tensor at the top of the stack.
@ -1010,7 +1011,7 @@ class StackDestroy(PrimitiveWithInfer):
Please refer to the usage in source code of `StackInit`.
Args:
index (int): The index of the stack.
index (int): The index of the stack. Default: 1.
Supported Platforms:
``Ascend``

View File

@ -136,9 +136,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
>>> min_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min_value, max_value)
"""
support_quant_bit = [4, 7, 8]
ascend_support_x_rank = [2, 4]
@ -519,7 +519,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
... input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
>>> output_tensor # shape: (3, 16, 5, 5) data type: mstype.float32
"""
@prim_attr_register
@ -581,9 +581,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8,narrow_range=False)
... (gradients, input_tensor, min_tensor, max_tensor)
>>> x_gradient shape: (3, 16, 5, 5) data type: mstype.float32
>>> min_gradient shape: (1,) data type: mstype.float32
>>> max_gradient shape: (1,) data type: mstype.float32
>>> x_gradient # shape: (3, 16, 5, 5) data type: mstype.float32
>>> min_gradient # shape: (1,) data type: mstype.float32
>>> max_gradient # shape: (1,) data type: mstype.float32
"""
@prim_attr_register
@ -642,7 +642,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
>>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
... input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
>>> output_tensor # shape: (3, 16, 3, 4) data type: mstype.float32
"""
@prim_attr_register
@ -698,9 +698,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
>>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
... num_bits=8, narrow_range=False)(
... gradients, input_tensor, min_tensor, max_tensor)
>>> x_gradient shape: (3, 16, 3, 4) data type: mstype.float32
>>> min_gradient shape: (4,) data type: mstype.float32
>>> max_gradient shape: (4,) data type: mstype.float32
>>> x_gradient # shape: (3, 16, 3, 4) data type: mstype.float32
>>> min_gradient # shape: (4,) data type: mstype.float32
>>> max_gradient # shape: (4,) data type: mstype.float32
"""
@prim_attr_register
@ -728,6 +728,28 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_type, min_type, max_type
def _fake_quant_per_infer_dtype(prim_name, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
def _fake_quant_per_grad_infer_dtype(prim_name, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulates the quantize and dequantize operations in training time.
@ -797,19 +819,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
class FakeQuantPerLayerGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantPerLayerGrad operation.
Performs grad of FakeQuantPerLayer operation.
Examples:
>>> fake_min_max_grad = FakeQuantPerLayerGrad()
@ -852,14 +867,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type
return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
class FakeQuantPerChannel(PrimitiveWithInfer):
@ -946,19 +954,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type
return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
class FakeQuantPerChannelGrad(PrimitiveWithInfer):
r"""
Performs grad of FakeQuantPerChannelGrad operation.
Performs grad of FakeQuantPerChannel operation.
Examples:
>>> fqmmpc_grad = FakeQuantPerChannelGrad()
@ -1001,14 +1002,7 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
return dout_shape
def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_dtypes = (mstype.float32,)
else:
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type
return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
class BatchNormFold(PrimitiveWithInfer):
@ -1298,7 +1292,7 @@ class BatchNormFold2(PrimitiveWithInfer):
class BatchNormFold2Grad(PrimitiveWithInfer):
r"""
Performs grad of CorrectionAddGrad operation.
Performs grad of BatchNormFold2 operation.
Examples:
>>> bnf2_grad = ops.BatchNormFold2Grad()
@ -1386,7 +1380,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
class BatchNormFoldGradD(PrimitiveWithInfer):
"""Performs grad of _BatchNormFoldGrad operation."""
"""Performs grad of BatchNormFold operation."""
@prim_attr_register
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
@ -1460,7 +1454,7 @@ class BatchNormFold2D(PrimitiveWithInfer):
class BatchNormFold2GradD(PrimitiveWithInfer):
"""Performs grad of CorrectionAddGrad operation."""
"""Performs grad of BatchNormFold2 operation."""
channel_axis = 1
@prim_attr_register
@ -1678,7 +1672,6 @@ class WtsARQ(PrimitiveWithInfer):
The WtsARQ(Weights Adaptive Range Quantization).
Args:
axes (list): Specify channels for ARQ algorithm.
num_bits (int): The bits num used for quantize.
offset_flag (bool): Whether use offset for quantize.

View File

@ -829,7 +829,7 @@ class Gather(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize Gather"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
@ -839,11 +839,10 @@ class GatherV2(PrimitiveWithCheck):
Please use Gather instead.
"""
@deprecated("1.1", "Gather", True)
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize GatherV2"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __check__(self, params, indices, axis):
@ -889,7 +888,7 @@ class SparseGatherV2(PrimitiveWithCheck):
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize SparseGatherV2"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
@ -909,7 +908,7 @@ class Padding(PrimitiveWithInfer):
Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
Args:
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive.
pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.
Inputs:
- **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
@ -1618,7 +1617,7 @@ class InvertPermutation(PrimitiveWithInfer):
if mstype.issubclass_(x['dtype'], mstype.tensor):
raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.')
for shp in x_shp:
if shp != []:
if shp:
x_rank = len(np.array(x_value, np.int64).shape)
raise ValueError(f'For \'{self.name}\' the rank of input must be 1, but got {x_rank}.')
for i, value in enumerate(x_value):
@ -2055,7 +2054,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
validator.check(f'rank of input_x', len(x_shp),
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
if (not -1 in x_shp and not -1 in segment_ids_shp):
if -1 not in x_shp and -1 not in segment_ids_shp:
# only validate when both shapes fully known
for i, value in enumerate(segment_ids_shp):
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
@ -2145,7 +2144,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
if -1 not in x_shape and -1 not in segment_ids_shape:
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -2205,7 +2204,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
if (not -1 in x_shape and not -1 in segment_ids_shape):
if -1 not in x_shape and -1 not in segment_ids_shape:
# only validate when both shapes fully known
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@ -2670,20 +2669,20 @@ class Slice(PrimitiveWithInfer):
>>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
... [[3, 3, 3], [4, 4, 4]],
... [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
>>> slice = ops.Slice()
>>> output = slice(data, (1, 0, 0), (1, 1, 3))
>>> slice_op = ops.Slice()
>>> output = slice_op(data, (1, 0, 0), (1, 1, 3))
>>> print(output)
[[[3 3 3]]]
>>> output = slice(data, (1, 0, 0), (1, 1, 2))
>>> output = slice_op(data, (1, 0, 0), (1, 1, 2))
>>> print(output)
[[[3 3]]]
>>> output = slice(data, (1, 0, 0), (1, 1, 1))
>>> output = slice_op(data, (1, 0, 0), (1, 1, 1))
>>> print(output)
[[[3]]]
>>> output = slice(data, (1, 1, 0), (1, 1, 3))
>>> output = slice_op(data, (1, 1, 0), (1, 1, 3))
>>> print(output)
[[[4 4 4]]]
>>> output = slice(data, (1, 0, 1), (1, 1, 2))
>>> output = slice_op(data, (1, 0, 1), (1, 1, 2))
>>> print(output)
[[[3 3]]]
"""
@ -2755,6 +2754,7 @@ class ReverseV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis):
"""Initialize ReverseV2."""
validator.check_value_type('axis', axis, [list, tuple], self.name)
for i, each in enumerate(axis):
validator.check_value_type(f'axis[{i}]', each, [int], self.name)
@ -2816,6 +2816,7 @@ class Rint(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Rint."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
@ -2885,7 +2886,7 @@ class Select(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize Select."""
self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
def infer_shape(self, cond_shape, x_shape, y_shape):
@ -3068,9 +3069,9 @@ class StridedSlice(PrimitiveWithInfer):
>>> # [6,6,6]
>>> # ]
>>> # ]
>>> slice = ops.StridedSlice()
>>> output = slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
>>> # Take the call of operator " output = slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
>>> strided_slice = ops.StridedSlice()
>>> output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
>>> # Take this " output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
>>> # start = [1, 0, 2] , end = [3, 1, 3], stride = [1, 1, 1], Find a segment of (start, end),
>>> # note that end is an open interval
>>> # To facilitate understanding, this operator can be divided into three steps:
@ -3113,7 +3114,7 @@ class StridedSlice(PrimitiveWithInfer):
>>> # The final output after finishing is:
[[[3], [5]]]
>>> # anothor example like :
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
>>> output = strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
>>> print(output)
[[[3. 3. 3.]]]
"""
@ -5269,7 +5270,7 @@ class Meshgrid(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, indexing="xy"):
"""Init Meshgrid"""
"""Initialize Meshgrid."""
validator.check_value_type("indexing", indexing, (str), self.name)
if indexing not in ("xy", "ij"):
raise ValueError("indexing parameter must be either 'xy' or 'ij'")
@ -5567,6 +5568,7 @@ class TransShape(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize TransShape."""
self.__setattr_flag__ = True
def __infer__(self, x, shape):
@ -5673,7 +5675,7 @@ class EmbeddingLookup(PrimitiveWithCheck):
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize EmbeddingLookup."""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
outputs=['output'])

View File

@ -109,6 +109,7 @@ class AllReduce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize AllReduce."""
if not isinstance(op, type(ReduceOp.SUM)):
raise TypeError("The operation of AllReduce should be str.")
if not isinstance(_get_group(group), str):
@ -182,6 +183,7 @@ class AllGather(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize AllGather."""
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
@ -215,6 +217,7 @@ class _MiniStepAllGather(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
"""Initialize _MiniStepAllGather."""
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.rank = get_rank(_get_group(group))
self.rank_size = get_group_size(_get_group(group))
@ -247,7 +250,7 @@ class _HostAllGather(PrimitiveWithInfer):
mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
Args:
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
Raises:
TypeError: If group is not a list nor tuple, or elements of group are not int.
@ -263,6 +266,7 @@ class _HostAllGather(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=None):
"""Initialize _HostAllGather."""
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('group', group, (tuple, list), self.name)
@ -338,6 +342,7 @@ class ReduceScatter(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize ReduceScatter."""
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.op = op
@ -347,9 +352,11 @@ class ReduceScatter(PrimitiveWithInfer):
self.add_prim_attr('fusion', 0)
def infer_shape(self, x_shape):
if self.rank_size == 0:
raise ValueError(f"For '{self.name}' rank_size can not be zero.")
if x_shape[0] % self.rank_size != 0:
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
x_shape[0] = int(x_shape[0]/self.rank_size)
x_shape[0] = int(x_shape[0] / self.rank_size)
return x_shape
def infer_dtype(self, x_dtype):
@ -373,7 +380,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
Args:
op (str): Specifies an operation used for element-wise reductions,
like sum, max, avg. Default: ReduceOp.SUM.
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: None.
Raises:
TypeError: If op is not a string and group is not a list nor tuple,
@ -383,6 +390,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, op=ReduceOp.SUM, group=None):
"""Initialize _HostReduceScatter."""
if group is None:
raise ValueError(f"For '{self.name}' group must be set.")
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
@ -398,7 +406,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
def infer_shape(self, x_shape):
if x_shape[0] % self.group_size != 0:
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
x_shape[0] = int(x_shape[0]/self.group_size)
x_shape[0] = int(x_shape[0] / self.group_size)
return x_shape
def infer_dtype(self, x_dtype):
@ -465,6 +473,7 @@ class Broadcast(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize Broadcast."""
validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
check_hcom_group_valid(group)
@ -592,6 +601,7 @@ class _MirrorOperator(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=None, dev_num=None, mean_flag=None):
"""Initialize _MirrorOperator."""
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
@ -621,6 +631,7 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
"""Initialize _MirrorMiniStepOperator."""
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
@ -645,6 +656,7 @@ class _VirtualDiv(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, divisor=None):
"""Initialize _VirtualDiv."""
self.divisor = divisor
def infer_shape(self, x_shape):
@ -661,7 +673,7 @@ class _VirtualAdd(PrimitiveWithInfer):
"""Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize _VirtualAdd."""
def infer_shape(self, x_shape, y_shape):
return x_shape
@ -679,7 +691,7 @@ class _VirtualDataset(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize _VirtualDataset."""
def infer_shape(self, *args):
return args
@ -696,12 +708,10 @@ class _VirtualAssignAdd(PrimitiveWithInfer):
Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
internal use of parallel modules and cannot be called by users.
Args:
micro (int): MicroBatch. Default: 0.
"""
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize _VirtualAssignAdd."""
def infer_shape(self, x_shape, y_shape):
return x_shape
@ -720,7 +730,7 @@ class _VirtualAccuGrad(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize _VirtualAccuGrad."""
def infer_shape(self, x_shape, y_shape):
return x_shape
@ -745,6 +755,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, group=None, dev_num=None, mean_flag=None):
"""Initialize _MirrorMicroStepOperator."""
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
@ -765,7 +776,7 @@ class _VirtualOutput(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize _VirtualOutput."""
def infer_shape(self, x_shape):
return x_shape
@ -784,7 +795,7 @@ class _GetTensorSlice(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize ChunkTensor"""
"""Initialize _GetTensorSlice."""
def infer_value(self, x, dev_mat, tensor_map):
from mindspore.parallel._tensor import _load_tensor

View File

@ -68,20 +68,20 @@ class GeSwitch(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize GeSwitch."""
def __call__(self, data, pred):
raise NotImplementedError
def infer_shape(self, data, pred):
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
return (data, data)
return data, data
def infer_dtype(self, data_type, pred_type):
validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
return (data_type, data_type)
return data_type, data_type
class Merge(PrimitiveWithInfer):
@ -108,13 +108,13 @@ class Merge(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize Merge."""
def __call__(self, *args):
raise NotImplementedError
def infer_shape(self, inputs):
return (inputs[0], [1])
return inputs[0], [1]
def infer_dtype(self, inputs):
args = {}
@ -122,4 +122,4 @@ class Merge(PrimitiveWithInfer):
args['inputs[%d]' % i] = item
validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)
return inputs[0], mstype.int32

View File

@ -86,7 +86,7 @@ class ScalarSummary(Primitive):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize ScalarSummary."""
self.add_prim_attr("side_effect_io", True)
@ -125,7 +125,7 @@ class ImageSummary(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize ImageSummary."""
self.add_prim_attr("side_effect_io", True)
def __infer__(self, name, value):
@ -177,7 +177,7 @@ class TensorSummary(Primitive):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize TensorSummary."""
self.add_prim_attr("side_effect_io", True)
@ -217,7 +217,7 @@ class HistogramSummary(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
"""Initialize HistogramSummary."""
self.add_prim_attr("side_effect_io", True)
def __infer__(self, name, value):
@ -285,6 +285,7 @@ class InsertGradientOf(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, f):
"""Initialize InsertGradientOf."""
self.add_prim_attr('side_effect_backprop', True)
self.f = f
@ -337,6 +338,7 @@ class HookBackward(PrimitiveWithInfer):
"""
def __init__(self, hook_fn, cell_id=""):
"""Initialize HookBackward."""
super(HookBackward, self).__init__(self.__class__.__name__)
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
@ -398,6 +400,7 @@ class Print(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Print."""
self.add_prim_attr("side_effect_io", True)
def __call__(self, *args):

View File

@ -173,6 +173,7 @@ class TensorAdd(_MathBinaryOp):
@deprecated("1.1", "Add", True)
@prim_attr_register
def __init__(self):
"""Initialize TensorAdd."""
_MathBinaryOp.__init__(self)
def infer_value(self, x, y):
@ -310,7 +311,7 @@ class _Reduce(PrimitiveWithInfer):
Args:
keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions.
If false, don't keep these dimensions. Default: False.
"""
__mindspore_signature__ = (
@ -603,7 +604,7 @@ class ReduceMax(_Reduce):
@prim_attr_register
def __init__(self, keep_dims=False):
"""ReduceMax"""
"""Initialize ReduceMax."""
super(ReduceMax, self).__init__(keep_dims)
self.__setattr_flag__ = True
@ -745,6 +746,7 @@ class CumProd(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, exclusive=False, reverse=False):
"""Initialize CumProd."""
cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
@ -803,6 +805,7 @@ class MatMul(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
"""Initialize MatMul."""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
@ -908,6 +911,7 @@ class BatchMatMul(MatMul):
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
"""Initialize BatchMatMul."""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
@ -1016,13 +1020,14 @@ class AddN(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize AddN."""
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def check_elim(self, inputs):
if len(inputs) != 1:
return (False, None)
return False, None
if isinstance(inputs[0], Tensor):
return (True, inputs[0])
return True, inputs[0]
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
@ -1068,14 +1073,15 @@ class AccumulateNV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize AccumulateNV2."""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def check_elim(self, inputs):
if len(inputs) != 1:
return (False, None)
return False, None
if isinstance(inputs[0], Tensor):
return (True, inputs[0])
return True, inputs[0]
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
def infer_shape(self, inputs):
@ -1732,7 +1738,7 @@ class Expm1(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Exp"""
"""Initialize Expm1."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
@ -1770,15 +1776,16 @@ class HistogramFixedWidth(PrimitiveWithInfer):
Examples:
>>> x = Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mindspore.float16)
>>> range = Tensor([0.0, 5.0], mindspore.float16)
>>> range_op = Tensor([0.0, 5.0], mindspore.float16)
>>> hist = ops.HistogramFixedWidth(5)
>>> output = hist(x, range)
>>> output = hist(x, range_op)
>>> print(output)
[2 1 1 0 2]
"""
@prim_attr_register
def __init__(self, nbins, dtype='int32'):
"""Initialize HistogramFixedWidth."""
self.nbins = validator.check_value_type("nbins", nbins, [int], self.name)
validator.check_int(nbins, 1, Rel.GE, "nbins", self.name)
valid_values = ['int32', 'int64']
@ -1825,6 +1832,7 @@ class Log(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Log."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x):
@ -1870,6 +1878,7 @@ class Log1p(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize Log1p."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
@ -2437,6 +2446,7 @@ class Floor(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Floor."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
@ -2515,6 +2525,7 @@ class Ceil(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Ceil."""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x_shape):
@ -3782,11 +3793,11 @@ class NMSWithMask(PrimitiveWithInfer):
validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name)
validator.check_equal_int(bboxes_shape[1], 5, "bboxes.shape[1]", cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))
return bboxes_shape, (num,), (num,)
def infer_dtype(self, bboxes_dtype):
validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_)
return bboxes_dtype, mstype.int32, mstype.bool_
class Abs(PrimitiveWithInfer):
@ -3811,8 +3822,8 @@ class Abs(PrimitiveWithInfer):
Examples:
>>> input_x = Tensor(np.array([-1.0, 1.0, 0.0]), mindspore.float32)
>>> abs = ops.Abs()
>>> output = abs(input_x)
>>> abs_op = ops.Abs()
>>> output = abs_op(input_x)
>>> print(output)
[1. 1. 0.]
"""
@ -3896,8 +3907,8 @@ class Round(PrimitiveWithInfer):
Examples:
>>> input_x = Tensor(np.array([0.8, 1.5, 2.3, 2.5, -4.5]), mindspore.float32)
>>> round = ops.Round()
>>> output = round(input_x)
>>> round_op = ops.Round()
>>> output = round_op(input_x)
>>> print(output)
[ 1. 2. 2. 2. -4.]
"""

View File

@ -172,6 +172,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, output_size):
"""Initialize AdaptiveAvgPool2D."""
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
if isinstance(output_size, tuple):
validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name)
@ -187,7 +188,7 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer):
out_size = [i if i else j for i, j in zipped]
for item in out_size:
validator.check_value_type("item of output_size", item, [int], self.name)
self.add_prim_attr('output_size', (out_size))
self.add_prim_attr('output_size', out_size)
output_shape = x_shape[:len(x_shape) - len(out_size)] + out_size
return output_shape
@ -238,6 +239,7 @@ class Softmax(Primitive):
@prim_attr_register
def __init__(self, axis=-1):
"""Initialize Softmax."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int, tuple], self.name)
if isinstance(axis, int):
@ -286,6 +288,7 @@ class LogSoftmax(Primitive):
@prim_attr_register
def __init__(self, axis=-1):
"""Initialize LogSoftmax."""
validator.check_value_type("axis", axis, [int], self.name)
@ -684,6 +687,7 @@ class HSwish(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize HSwish."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, xshape):
@ -728,6 +732,7 @@ class Sigmoid(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize Sigmoid."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x):
@ -774,6 +779,7 @@ class HSigmoid(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize HSigmoid."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
@ -865,7 +871,6 @@ class InstanceNorm(PrimitiveWithInfer):
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
(e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
Momentum value must be [0, 1]. Default: 0.1.
data_format (str): The optional value for data format, is 'NCHW'. Default: "NCHW".
Inputs:
- **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C)`,
@ -928,6 +933,7 @@ class InstanceNorm(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, epsilon=1e-5, momentum=0.1):
"""Initialize InstanceNorm."""
self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
outputs=['y', 'save_mean', 'save_variance'])
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
@ -945,7 +951,7 @@ class InstanceNorm(PrimitiveWithInfer):
validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
save_mean_shape = gamma
save_mean_shape[0] = save_mean_shape[0] * input_shape_norm[0]
return (input_x, save_mean_shape, save_mean_shape)
return input_x, save_mean_shape, save_mean_shape
def infer_dtype(self, input_x, gamma, beta, mean, variance):
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
@ -954,7 +960,7 @@ class InstanceNorm(PrimitiveWithInfer):
args_moving = {"mean": mean, "variance": variance}
valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return (input_x, gamma, gamma)
return input_x, gamma, gamma
class BNTrainingReduce(PrimitiveWithInfer):
@ -992,15 +998,16 @@ class BNTrainingReduce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize BNTrainingReduce."""
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
def infer_shape(self, x_shape):
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
return ([x_shape[1]], [x_shape[1]])
return [x_shape[1]], [x_shape[1]]
def infer_dtype(self, x_type):
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return (x_type, x_type)
return x_type, x_type
class BNTrainingUpdate(PrimitiveWithInfer):
@ -1076,6 +1083,7 @@ class BNTrainingUpdate(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
"""Initialize BNTrainingUpdate."""
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
validator.check_value_type("isRef", isRef, [bool], self.name)
@ -1098,14 +1106,14 @@ class BNTrainingUpdate(PrimitiveWithInfer):
validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
return (x, variance, variance, variance, variance)
return x, variance, variance, variance, variance
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "sum", "square_sum", "scale", "b", "mean", "variance"),
(x, sum, square_sum, scale, b, mean, variance)))
return (x, variance, variance, variance, variance)
return x, variance, variance, variance, variance
class BatchNorm(PrimitiveWithInfer):
@ -1203,6 +1211,7 @@ class BatchNorm(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
"""Initialize BatchNorm."""
if is_training is False:
self.set_signatures(tuple())
validator.check_value_type('is_training', is_training, (bool,), self.name)
@ -1224,13 +1233,13 @@ class BatchNorm(PrimitiveWithInfer):
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale)
return input_x, scale, scale, scale, scale
def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias, "mean": mean, "variance": variance}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return (input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32)
return input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32
class Conv2D(Primitive):
@ -1524,6 +1533,7 @@ class _Pool(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
"""Initialize _Pool."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
@ -1642,6 +1652,7 @@ class MaxPool(_Pool):
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
"""Initialize MaxPool."""
super(MaxPool, self).__init__(kernel_size, strides, pad_mode, data_format)
@ -1710,6 +1721,7 @@ class MaxPoolWithArgmax(_Pool):
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
"""Initialize MaxPoolWithArgmax."""
super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
def infer_shape(self, x_shape):
@ -1784,6 +1796,7 @@ class MaxPool3D(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCDHW"):
"""Initialize MaxPool3D."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
@ -1900,6 +1913,7 @@ class AvgPool(_Pool):
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
"""Initialize AvgPool."""
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
@ -2074,6 +2088,7 @@ class Conv2DTranspose(Conv2DBackpropInput):
@prim_attr_register
def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0,
pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"):
"""Initialize Conv2DTranspose."""
super(Conv2DTranspose, self).__init__(out_channel, kernel_size, pad_mode, pad,
pad_list, mode, stride, dilation, group, data_format)
@ -2118,6 +2133,7 @@ class BiasAdd(Primitive):
@prim_attr_register
def __init__(self, data_format="NCHW"):
"""Initialize BiasAdd."""
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
@ -2177,6 +2193,7 @@ class TopK(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, sorted=False):
"""Initialize TopK."""
validator.check_value_type("sorted", sorted, [bool], self.name)
self.init_prim_io_names(inputs=['input', 'k'],
outputs=['values', 'indices'])
@ -2340,12 +2357,12 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name)
loss_shape = [logits_shape[0]]
dlogits_shape = logits_shape
return (loss_shape, dlogits_shape)
return loss_shape, dlogits_shape
def infer_dtype(self, logits_type, labels_type):
args = {"logits": logits_type, "labels": labels_type}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return (logits_type, logits_type)
return logits_type, logits_type
class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
@ -2401,6 +2418,7 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, is_grad=False):
"""Initialize SparseSoftmaxCrossEntropyWithLogits."""
validator.check_value_type('is_grad', is_grad, [bool], self.name)
self.init_prim_io_names(inputs=['features', 'labels'], outputs=['output'])
self.is_grad = is_grad
@ -2473,6 +2491,7 @@ class ApplyMomentum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
"""Initialize ApplyMomentum."""
self.use_nesterov = validator.check_bool(use_nesterov)
self.use_locking = validator.check_bool(use_locking)
validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
@ -2543,6 +2562,7 @@ class SmoothL1Loss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss."""
validator.check_value_type('beta', beta, [float], self.name)
validator.check('beta', beta, '', 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
@ -2636,6 +2656,7 @@ class DataFormatDimMap(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, src_format='NHWC', dst_format='NCHW'):
"""Initialize DataFormatDimMap."""
valid_values = ['NHWC', 'NCHW']
self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
@ -2692,6 +2713,7 @@ class RNNTLoss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, blank_label=0):
"""Initialize RNNTLoss."""
validator.check_value_type('blank_label', blank_label, [int], self.name)
self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
outputs=['costs', 'grads'])
@ -2706,7 +2728,7 @@ class RNNTLoss(PrimitiveWithInfer):
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],)
return (costs_shape, acts_shape)
return costs_shape, acts_shape
def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name)
@ -2714,7 +2736,7 @@ class RNNTLoss(PrimitiveWithInfer):
valid_dtypes=(mstype.int32,), prim_name=self.name),
("labels", "input_length", "label_length"),
(labels_type, input_length_type, label_length_type)))
return (acts_type, acts_type)
return acts_type, acts_type
class SGD(PrimitiveWithCheck):
@ -2771,6 +2793,7 @@ class SGD(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
"""Initialize SGD."""
validator.check_value_type("nesterov", nesterov, [bool], self.name)
if nesterov and dampening != 0:
raise ValueError(f"Nesterov need zero dampening!")
@ -2863,6 +2886,7 @@ class ApplyRMSProp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ApplyRMSProp."""
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
'rho', 'momentum', 'epsilon'], outputs=['output'])
@ -2969,6 +2993,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ApplyCenteredRMSProp."""
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.add_prim_attr('side_effect_mem', True)
@ -3056,6 +3081,7 @@ class LayerNorm(Primitive):
@prim_attr_register
def __init__(self, begin_norm_axis=1, begin_params_axis=1, epsilon=1e-7):
"""Initialize LayerNorm."""
validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
validator.check_value_type('epsilon', epsilon, [float], self.name)
@ -3102,6 +3128,7 @@ class L2Normalize(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0, epsilon=1e-4):
"""Initialize L2Normalize."""
axis = [axis] if isinstance(axis, int) else axis
validator.check_value_type('axis', axis, [list, tuple], self.name)
validator.check_value_type('epsilon', epsilon, [int, float], self.name)
@ -3163,6 +3190,7 @@ class DropoutGenMask(Primitive):
@prim_attr_register
def __init__(self, Seed0=0, Seed1=0):
"""Initialize DropoutGenMask."""
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
validator.check_value_type("Seed0", Seed0, [int], self.name)
validator.check_value_type("Seed1", Seed1, [int], self.name)
@ -3264,6 +3292,7 @@ class ResizeBilinear(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, size, align_corners=False):
"""Initialize ResizeBilinear."""
validator.check_value_type("size", size, [tuple, list], self.name)
validator.check_equal_int(len(size), 2, "size len", self.name)
for item in size:
@ -3337,6 +3366,7 @@ class OneHot(Primitive):
@prim_attr_register
def __init__(self, axis=-1):
"""Initialize OneHot."""
self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name)
@ -3412,7 +3442,7 @@ class FastGelu(PrimitiveWithInfer):
@deprecated("1.1", "FastGeLU", True)
@prim_attr_register
def __init__(self):
"""init FastGelu"""
"""Initialize FastGelu."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x):
@ -3457,7 +3487,7 @@ class FastGeLU(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init FastGeLU"""
"""Initialize FastGeLU."""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x):
@ -3509,6 +3539,7 @@ class GetNext(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, types, shapes, output_num, shared_name):
"""Initialize GetNext."""
validator.check_value_type("types", types, [list, tuple], self.name)
validator.check_value_type("shapes", shapes, [list, tuple], self.name)
validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
@ -3674,6 +3705,7 @@ class LSTM(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
"""Initialize LSTM."""
self.input_size = validator.check_positive_int(input_size, "input_size", self.name)
self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name)
self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name)
@ -3704,12 +3736,12 @@ class LSTM(PrimitiveWithInfer):
# set arbitrary shape for reserved space
reserved_shape = (1, 1)
state_shape = (1, 1)
return (y_shape, h_shape, c_shape, reserved_shape, state_shape)
return y_shape, h_shape, c_shape, reserved_shape, state_shape
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
@ -4023,7 +4055,7 @@ class ComputeAccidentalHits(PrimitiveWithCheck):
the weight is -FLOAT_MAX. FLOAT_MAX indicates the max value in the type of Float
Args:
num_true (int): The number of target classes per training example.
num_true (int): The number of target classes per training example. Default: 1.
Inputs:
- **true_classes** (Tensor) - The target classes. With data type of int32 or int64
@ -4252,6 +4284,7 @@ class Adam(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
"""Initialize Adam."""
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
self.add_prim_attr('side_effect_mem', True)
@ -4368,6 +4401,7 @@ class AdamNoUpdateParam(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
"""Initialize AdamNoUpdateParam."""
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
@ -4501,6 +4535,7 @@ class FusedSparseAdam(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
"""Initialize FusedSparseAdam."""
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
@ -4649,6 +4684,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False, use_nesterov=False):
"""Initialize FusedSparseLazyAdam."""
validator.check_value_type("use_locking", use_locking, [bool], self.name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
@ -4762,6 +4798,7 @@ class FusedSparseFtrl(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
"""Initialize FusedSparseFtrl."""
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
@ -4878,6 +4915,7 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize FusedSparseProximalAdagrad"""
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
@ -4968,6 +5006,7 @@ class KLDivLoss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
"""Initialize KLDivLoss."""
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape):
@ -5055,6 +5094,7 @@ class BinaryCrossEntropy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction='mean'):
"""Initialize BinaryCrossEntropy."""
self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
def infer_shape(self, x_shape, y_shape, weight_shape):
@ -5442,6 +5482,7 @@ class ApplyAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, update_slots=True):
"""Initialize ApplyAdagrad."""
validator.check_value_type("update_slots", update_slots, [bool], self.name)
self.add_prim_attr('side_effect_mem', True)
@ -5543,6 +5584,7 @@ class ApplyAdagradV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, epsilon, update_slots=True):
"""Initialize ApplyAdagradV2."""
validator.check_value_type("epsilon", epsilon, [float], self.name)
validator.check_value_type("update_slots", update_slots, [bool], self.name)
self.add_prim_attr('side_effect_mem', True)
@ -5644,6 +5686,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, lr, update_slots=True, use_locking=False):
"""Initialize SparseApplyAdagrad."""
validator.check_value_type("lr", lr, [float], self.name)
validator.check_is_float(lr, "lr", self.name)
validator.check_value_type("update_slots", update_slots, [bool], self.name)
@ -5748,6 +5791,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, lr, epsilon, use_locking=False, update_slots=True):
"""Initialize SparseApplyAdagradV2."""
self.lr = validator.check_value_type("lr", lr, [float], self.name)
self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name)
self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name)
@ -5860,6 +5904,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ApplyProximalAdagrad."""
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],
outputs=['var', 'accum'])
self.add_prim_attr('side_effect_mem', True)
@ -5985,6 +6030,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize SparseApplyProximalAdagrad."""
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
outputs=['var', 'accum'])
self.add_prim_attr('side_effect_mem', True)
@ -6095,7 +6141,7 @@ class ApplyAddSign(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"Initialize ApplyAddSign"
"""Initialize ApplyAddSign."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape,
@ -6225,7 +6271,7 @@ class ApplyPowerSign(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"Initialize ApplyPowerSign"
"""Initialize ApplyPowerSign."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,
@ -6320,7 +6366,7 @@ class ApplyGradientDescent(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"Initialize ApplyGradientDescent"
"""Initialize ApplyGradientDescent."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, alpha_shape, delta_shape):
@ -6408,7 +6454,7 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"Initialize ApplyGradientDescent"
"""Initialize ApplyGradientDescent."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
@ -6495,7 +6541,7 @@ class LARSUpdate(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False):
"""init"""
"""Initialize LARSUpdate."""
validator.check_value_type("epsilon", epsilon, [float], self.name)
validator.check_value_type("hyperpara", hyperpara, [float], self.name)
validator.check_value_type("use_clip", use_clip, [bool], self.name)
@ -6601,6 +6647,7 @@ class ApplyFtrl(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ApplyFtrl."""
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
@ -6706,6 +6753,7 @@ class SparseApplyFtrl(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
"""Initialize SparseApplyFtrl."""
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
@ -6819,6 +6867,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False):
"""Initialize SparseApplyFtrlV2."""
validator.check_value_type("lr", lr, [float], self.name)
validator.check_value_type("l1", l1, [float], self.name)
validator.check_value_type("l2", l2, [float], self.name)
@ -6854,7 +6903,7 @@ class Dropout(PrimitiveWithCheck):
Args:
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.
means dropping out 10% of input units. Default: 0.5.
Seed0 (int): Seed0 value for random generating. Default: 0.
Seed1 (int): Seed1 value for random generating. Default: 0.
@ -6884,6 +6933,7 @@ class Dropout(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0):
"""Initialize Dropout."""
self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name)
self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
@ -6938,6 +6988,7 @@ class Dropout2D(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=0.5):
"""Initialize Dropout2D."""
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
@ -6995,6 +7046,7 @@ class Dropout3D(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=0.5):
"""Initialize Dropout3D."""
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
@ -7081,6 +7133,7 @@ class CTCLoss(Primitive):
@prim_attr_register
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False):
"""Initialize CTCLoss."""
self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
outputs=["loss", "gradient"])
validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name)
@ -7140,6 +7193,7 @@ class CTCGreedyDecoder(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, merge_repeated=True):
"""Initialize CTCGreedyDecoder."""
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
def check_shape(self, inputs_shape, sequence_length_shape):
@ -7169,6 +7223,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
"""Initialize BasicLSTMCell."""
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
@ -7195,7 +7250,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
ot_shape = c_shape
tanhct_shape = c_shape
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
return ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
tuple(map(partial(validator.check_tensor_dtype_valid,
@ -7204,7 +7259,7 @@ class BasicLSTMCell(PrimitiveWithInfer):
(x_dtype, h_dtype, w_dtype)))
args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
return c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype
class DynamicRNN(PrimitiveWithInfer):
@ -7315,6 +7370,7 @@ class DynamicRNN(PrimitiveWithInfer):
activation='tanh',
forget_bias=0.0,
is_training=True):
"""Initialize DynamicRNN."""
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
@ -7480,6 +7536,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
gate_order="rzh",
reset_after=True,
is_training=True):
"""Initialize DynamicGRUV2."""
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
@ -7613,10 +7670,10 @@ class LRN(PrimitiveWithInfer):
\sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}
Args:
depth_radius (int): Half-width of the 1-D normalization window with the shape of 0-D.
bias (float): An offset (usually positive to avoid dividing by 0).
alpha (float): A scale factor, usually positive.
beta (float): An exponent.
depth_radius (int): Half-width of the 1-D normalization window with the shape of 0-D. Default: 5.
bias (float): An offset (usually positive to avoid dividing by 0). Default: 1.0.
alpha (float): A scale factor, usually positive. Default: 1.0.
beta (float): An exponent. Default: 0.5.
norm_region (str): Specifies normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".
Inputs:
@ -7801,12 +7858,12 @@ class Conv3D(PrimitiveWithInfer):
:math:`padding` is zero-padding added to both sides of the input.
Args:
out_channels (int): The number of output channel :math:`C_{out}`.
out_channel (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers. Specifies the depth, height
and width of the 3D convolution window. Single int means the value is for the depth, height and the width
of the kernel. A tuple of 3 ints means the first value is for the depth, height and the other is for the
width of the kernel.
mode (int): Modes for different convolutions. It is currently not used
mode (int): Modes for different convolutions. It is currently not used. Default: 1.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the depth, height and width of movement are both strides, or a tuple of three int numbers that
represent depth, height and width of movement respectively. Default: 1.

View File

@ -70,6 +70,7 @@ class Assign(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize Assign."""
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
@ -111,6 +112,7 @@ class InplaceAssign(PrimitiveWithInfer):
@deprecated("1.3", "Assign", False)
@ prim_attr_register
def __init__(self):
"""Initialize InplaceAssign."""
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
def infer_shape(self, x, y, z):
@ -137,6 +139,7 @@ class Load(PrimitiveWithCheck):
@prim_attr_register
def __init__(self):
"""Initialize Load."""
self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output'])
def check_dtype(self, variable):
@ -178,8 +181,9 @@ class BoundingBoxEncode(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
validator.check_value_type('means', means, (tuple), self.name)
validator.check_value_type('stds', stds, (tuple), self.name)
"""Initialize BoundingBoxEncode."""
validator.check_value_type('means', means, tuple, self.name)
validator.check_value_type('stds', stds, tuple, self.name)
for i, value in enumerate(means):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
@ -241,8 +245,9 @@ class BoundingBoxDecode(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
validator.check_value_type('means', means, (tuple), self.name)
validator.check_value_type('stds', stds, (tuple), self.name)
"""Initialize BoundingBoxDecode."""
validator.check_value_type('means', means, tuple, self.name)
validator.check_value_type('stds', stds, tuple, self.name)
for i, value in enumerate(means):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
@ -313,6 +318,7 @@ class CheckValid(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize CheckValid."""
self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])
def infer_shape(self, bboxes_shape, metas_shape):
@ -372,6 +378,7 @@ class IOU(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, mode='iou'):
"""Initialize IOU."""
if mode not in {'iou', 'iof'}:
raise KeyError("Mode only support 'iou' or 'iof'.")
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])
@ -407,6 +414,7 @@ class Partial(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize Partial."""
self.add_prim_attr('side_effect_propagate', 1)
def __call__(self, *args):
@ -473,6 +481,7 @@ class Depend(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize Depend."""
self.add_prim_attr('side_effect_propagate', 1)
def __call__(self, value, expr):
@ -603,6 +612,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, num_classes, dtype="int32"):
"""Initialize ConfusionMatrix."""
validator.check_value_type("num_classes", num_classes, [int], self.name)
validator.check_value_type("dtype", dtype, [str], self.name)
@ -781,6 +791,7 @@ class identity(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize identity."""
self.add_prim_attr('side_effect_propagate', 1)
def __call__(self, x):

View File

@ -336,7 +336,7 @@ class UniformInt(PrimitiveWithInfer):
return out
class UniformReal(PrimitiveWithInfer):
class UniformReal(StandardNormal):
r"""
Produces random floating-point values i, uniformly distributed to the interval [0, 1).
@ -367,27 +367,6 @@ class UniformReal(PrimitiveWithInfer):
(2, 2)
"""
@prim_attr_register
def __init__(self, seed=0, seed2=0):
"""Initialize UniformReal"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)
def __infer__(self, shape):
shape_v = shape["value"]
if shape_v is None:
raise ValueError(f"For {self.name}, shape must be const.")
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
out = {
'shape': shape_v,
'dtype': mstype.float32,
'value': None}
return out
class RandomChoiceWithMask(PrimitiveWithInfer):
"""
@ -445,11 +424,11 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
def infer_shape(self, x_shape):
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
Validator.check_int(len(x_shape), 5, Rel.LE, "input_x rank", self.name)
return ([self.count, len(x_shape)], [self.count])
return [self.count, len(x_shape)], [self.count]
def infer_dtype(self, x_dtype):
Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_)
return mstype.int32, mstype.bool_
class RandomCategorical(PrimitiveWithInfer):
@ -539,12 +518,12 @@ class Multinomial(PrimitiveWithInfer):
seed2 (int): Random seed2, must be non-negative. Default: 0.
Inputs:
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
- **x** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
dimensions.
- **num_samples** (int32) - number of samples to draw.
Outputs:
Tensor with the same rows as input, each row has num_samples sampled indices.
Tensor with the same rows as `x`, each row has num_samples sampled indices.
Raises:
TypeError: If neither `seed` nor `seed2` is an int.
@ -555,16 +534,16 @@ class Multinomial(PrimitiveWithInfer):
``GPU``
Examples:
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> x = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = ops.Multinomial(seed=10)
>>> output = multinomial(input, 2)
>>> output = multinomial(x, 2)
>>> print(output)
[2 1]
"""
@prim_attr_register
def __init__(self, seed=0, seed2=0):
"""init"""
"""Initialize Multinomial."""
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
@ -655,11 +634,11 @@ class UniformCandidateSampler(PrimitiveWithInfer):
Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type,
(mstype.int32, mstype.int64), self.name)
return (true_classes_type, mstype.float32, mstype.float32)
return true_classes_type, mstype.float32, mstype.float32
def infer_shape(self, true_classes_shape):
Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
return ([self.num_sampled], true_classes_shape, [self.num_sampled])
return [self.num_sampled], true_classes_shape, [self.num_sampled]
class LogUniformCandidateSampler(PrimitiveWithInfer):
@ -675,7 +654,7 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
all sampled classes in a batch are unique. Default: True.
range_max (int): The number of possible classes. When `unique` is True,
`range_max` must be greater than or equal to `num_sampled`. Default: 5.
seed (int): Random seed, must be non-negative.
seed (int): Random seed, must be non-negative. Default: 0.
Inputs:
- **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true].

View File

@ -59,7 +59,7 @@ class SparseToDense(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
"""Initialize SparseToDense."""
self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
def __infer__(self, indices, values, sparse_shape):

View File

@ -58,8 +58,9 @@ class BondForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, bond_numbers, atom_numbers):
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize BondForce."""
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.bond_numbers = bond_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('bond_numbers', self.bond_numbers)
@ -131,8 +132,9 @@ class BondEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, bond_numbers, atom_numbers):
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize BondEnergy."""
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.bond_numbers = bond_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('bond_numbers', self.bond_numbers)
@ -199,8 +201,9 @@ class BondAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, bond_numbers, atom_numbers):
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize BondAtomEnergy."""
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.bond_numbers = bond_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('bond_numbers', self.bond_numbers)
@ -266,8 +269,9 @@ class BondForceWithAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, bond_numbers, atom_numbers):
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize BondForceWithAtomEnergy."""
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.bond_numbers = bond_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('bond_numbers', self.bond_numbers)
@ -346,8 +350,9 @@ class BondForceWithAtomVirial(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, bond_numbers, atom_numbers):
validator.check_value_type('bond_numbers', bond_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize BondForceWithAtomVirial."""
validator.check_value_type('bond_numbers', bond_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.bond_numbers = bond_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('bond_numbers', self.bond_numbers)
@ -457,7 +462,8 @@ class DihedralForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, dihedral_numbers):
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
"""Initialize DihedralForce."""
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
self.dihedral_numbers = dihedral_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
'gamc', 'gams', 'pn'],
@ -549,7 +555,8 @@ class DihedralEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, dihedral_numbers):
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
"""Initialize DihedralEnergy."""
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
self.dihedral_numbers = dihedral_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
'gamc', 'gams', 'pn'],
@ -639,7 +646,8 @@ class DihedralAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, dihedral_numbers):
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
"""Initialize DihedralAtomEnergy."""
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
self.dihedral_numbers = dihedral_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
'gamc', 'gams', 'pn'],
@ -729,7 +737,8 @@ class DihedralForceWithAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, dihedral_numbers):
validator.check_value_type('dihedral_numbers', dihedral_numbers, (int), self.name)
"""Initialize DihedralForceWithAtomEnergy."""
validator.check_value_type('dihedral_numbers', dihedral_numbers, int, self.name)
self.dihedral_numbers = dihedral_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'atom_d', 'ipn', 'pk',
'gamc', 'gams', 'pn'],
@ -826,7 +835,8 @@ class AngleForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, angle_numbers):
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
"""Initialize AngleForce."""
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
self.angle_numbers = angle_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
'angle_theta0'],
@ -902,7 +912,8 @@ class AngleEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, angle_numbers):
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
"""Initialize AngleEnergy."""
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
self.angle_numbers = angle_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
'angle_theta0'],
@ -972,7 +983,8 @@ class AngleAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, angle_numbers):
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
"""Initialize AngleAtomEnergy."""
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
self.angle_numbers = angle_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
'angle_theta0'],
@ -1043,7 +1055,8 @@ class AngleForceWithAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, angle_numbers):
validator.check_value_type('angle_numbers', angle_numbers, (int), self.name)
"""Initialize AngleForceWithAtomEnergy."""
validator.check_value_type('angle_numbers', angle_numbers, int, self.name)
self.angle_numbers = angle_numbers
self.init_prim_io_names(inputs=['uint_crd_f', 'scaler_f', 'atom_a', 'atom_b', 'atom_c', 'angle_k',
'angle_theta0'],
@ -1126,8 +1139,9 @@ class Dihedral14LJForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14LJForce."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
self.init_prim_io_names(
@ -1217,8 +1231,9 @@ class Dihedral14LJEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14LJEnergy"""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1312,8 +1327,9 @@ class Dihedral14LJForceWithDirectCF(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14LJForceWithDirectCF."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1409,8 +1425,9 @@ class Dihedral14LJCFForceWithAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14LJCFForceWithAtomEnergy."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1500,8 +1517,9 @@ class Dihedral14LJAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14LJAtomEnergy."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1589,8 +1607,9 @@ class Dihedral14CFEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14CFEnergy."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1668,8 +1687,9 @@ class Dihedral14CFAtomEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, nb14_numbers, atom_numbers):
validator.check_value_type('nb14_numbers', nb14_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize Dihedral14CFAtomEnergy."""
validator.check_value_type('nb14_numbers', nb14_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.dihedral_14_numbers = nb14_numbers
self.atom_numbers = atom_numbers
@ -1755,13 +1775,14 @@ class MDIterationLeapFrog(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
validator.check_value_type('float4_numbers', float4_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('half_dt', half_dt, (float), self.name)
validator.check_value_type('dt', dt, (float), self.name)
validator.check_value_type('exp_gamma', exp_gamma, (float), self.name)
validator.check_value_type('is_max_velocity', is_max_velocity, (int), self.name)
validator.check_value_type('max_velocity', max_velocity, (float), self.name)
"""Initialize MDIterationLeapFrog."""
validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('half_dt', half_dt, float, self.name)
validator.check_value_type('dt', dt, float, self.name)
validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
validator.check_value_type('max_velocity', max_velocity, float, self.name)
self.float4_numbers = float4_numbers
self.atom_numbers = atom_numbers
self.half_dt = half_dt
@ -1828,14 +1849,15 @@ class PMEReciprocalForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1, box_length_2):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('beta', beta, (float), self.name)
validator.check_value_type('fftx', fftx, (int), self.name)
validator.check_value_type('ffty', ffty, (int), self.name)
validator.check_value_type('fftz', fftz, (int), self.name)
validator.check_value_type('box_length_0', box_length_0, (float), self.name)
validator.check_value_type('box_length_1', box_length_1, (float), self.name)
validator.check_value_type('box_length_2', box_length_2, (float), self.name)
"""Initialize PMEReciprocalForce."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('beta', beta, float, self.name)
validator.check_value_type('fftx', fftx, int, self.name)
validator.check_value_type('ffty', ffty, int, self.name)
validator.check_value_type('fftz', fftz, int, self.name)
validator.check_value_type('box_length_0', box_length_0, float, self.name)
validator.check_value_type('box_length_1', box_length_1, float, self.name)
validator.check_value_type('box_length_2', box_length_2, float, self.name)
self.atom_numbers = atom_numbers
self.beta = beta
self.fftx = fftx
@ -1906,9 +1928,10 @@ class PMEExcludedForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, excluded_numbers, beta):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
validator.check_value_type('beta', beta, (float), self.name)
"""Initialize PMEExcludedForce."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
validator.check_value_type('beta', beta, float, self.name)
self.atom_numbers = atom_numbers
self.excluded_numbers = excluded_numbers
self.beta = beta
@ -1999,15 +2022,16 @@ class PMEEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, excluded_numbers, beta, fftx, ffty, fftz, box_length_0, box_length_1,
box_length_2):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, (int), self.name)
validator.check_value_type('beta', beta, (float), self.name)
validator.check_value_type('fftx', fftx, (int), self.name)
validator.check_value_type('ffty', ffty, (int), self.name)
validator.check_value_type('fftz', fftz, (int), self.name)
validator.check_value_type('box_length_0', box_length_0, (float), self.name)
validator.check_value_type('box_length_1', box_length_1, (float), self.name)
validator.check_value_type('box_length_2', box_length_2, (float), self.name)
"""Initialize PMEEnergy."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('excluded_numbers', excluded_numbers, int, self.name)
validator.check_value_type('beta', beta, float, self.name)
validator.check_value_type('fftx', fftx, int, self.name)
validator.check_value_type('ffty', ffty, int, self.name)
validator.check_value_type('fftz', fftz, int, self.name)
validator.check_value_type('box_length_0', box_length_0, float, self.name)
validator.check_value_type('box_length_1', box_length_1, float, self.name)
validator.check_value_type('box_length_2', box_length_2, float, self.name)
self.atom_numbers = atom_numbers
self.excluded_numbers = excluded_numbers
self.beta = beta
@ -2113,8 +2137,9 @@ class LJEnergy(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, cutoff_square):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('cutoff_square', cutoff_square, (float), self.name)
"""Initialize LJEnergy."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
self.atom_numbers = atom_numbers
self.cutoff_square = cutoff_square
self.init_prim_io_names(
@ -2199,8 +2224,9 @@ class LJForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, cutoff_square):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('cutoff_square', cutoff_square, (float), self.name)
"""Initialize LJForce."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('cutoff_square', cutoff_square, float, self.name)
self.atom_numbers = atom_numbers
self.cutoff_square = cutoff_square
self.init_prim_io_names(
@ -2280,9 +2306,10 @@ class LJForceWithPMEDirectForce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, cutoff, pme_beta):
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('cutoff', cutoff, (float), self.name)
validator.check_value_type('pme_beta', pme_beta, (float), self.name)
"""Initialize LJForceWithPMEDirectForce."""
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('cutoff', cutoff, float, self.name)
validator.check_value_type('pme_beta', pme_beta, float, self.name)
self.atom_numbers = atom_numbers
self.cutoff = cutoff
self.pme_beta = pme_beta
@ -2340,8 +2367,9 @@ class GetCenterOfGeometry(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, center_numbers, center_numbers_inverse):
validator.check_value_type('center_numbers', center_numbers, (int), self.name)
validator.check_value_type('center_numbers_inverse', center_numbers_inverse, (float), self.name)
"""Initialize GetCenterOfGeometry."""
validator.check_value_type('center_numbers', center_numbers, int, self.name)
validator.check_value_type('center_numbers_inverse', center_numbers_inverse, float, self.name)
self.center_numbers = center_numbers
self.center_numbers_inverse = center_numbers_inverse
self.add_prim_attr('center_numbers', self.center_numbers)
@ -2377,8 +2405,9 @@ class MDTemperature(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, residue_numbers, atom_numbers):
validator.check_value_type('residue_numbers', residue_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize MDTemperature."""
validator.check_value_type('residue_numbers', residue_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.residue_numbers = residue_numbers
self.atom_numbers = atom_numbers
self.add_prim_attr('residue_numbers', self.residue_numbers)
@ -2423,8 +2452,8 @@ class NeighborListUpdate(PrimitiveWithInfer):
list first time or not.
nxy(int32): the total number of grids divided in xy plane.
excluded_atom_numbers(int32): the total atom numbers in the excluded list.
cutoff(float32): the cutoff distance for short-range force calculation.
skin(float32): the overflow value of cutoff to maintain a neighbor list.
cutoff(float32): the cutoff distance for short-range force calculation. Default: 10.0.
skin(float32): the overflow value of cutoff to maintain a neighbor list. Default: 2.0.
cutoff_square(float32): the suqare value of cutoff.
half_skin_square(float32): skin*skin/4, indicates the maximum
square value of the distance atom allowed to move between two updates.
@ -2432,8 +2461,9 @@ class NeighborListUpdate(PrimitiveWithInfer):
radius of the neighbor list for each atom.
half_cutoff_with_skin(float32): cutoff_with_skin/2.
cutoff_with_skin_square(float32): the square value of cutoff_with_skin.
refresh_interval(int32): the number of iteration steps between two updates of neighbor list.
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid.
refresh_interval(int32): the number of iteration steps between two updates of neighbor list. Default: 20.
max_atom_in_grid_numbers(int32): the maximum number of atoms in one grid. Default: 64.
max_neighbor_numbers(int32): The maximum number of neighbors. Default: 800.
Inputs:
- **atom_numbers_in_grid_bucket** (Tensor, int32) - [G,], the number of atoms in each grid bucket.
@ -2470,6 +2500,7 @@ class NeighborListUpdate(PrimitiveWithInfer):
def __init__(self, grid_numbers, atom_numbers, not_first_time, nxy, excluded_atom_numbers,
cutoff_square, half_skin_square, cutoff_with_skin, half_cutoff_with_skin, cutoff_with_skin_square,
refresh_interval=20, cutoff=10.0, skin=2.0, max_atom_in_grid_numbers=64, max_neighbor_numbers=800):
"""Initialize NeighborListUpdate."""
self.grid_numbers = grid_numbers
self.atom_numbers = atom_numbers
self.refresh_interval = refresh_interval
@ -2641,13 +2672,14 @@ class MDIterationLeapFrogWithRF(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, float4_numbers, atom_numbers, half_dt, dt, exp_gamma, is_max_velocity, max_velocity):
validator.check_value_type('float4_numbers', float4_numbers, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
validator.check_value_type('half_dt', half_dt, (float), self.name)
validator.check_value_type('dt', dt, (float), self.name)
validator.check_value_type('exp_gamma', exp_gamma, (float), self.name)
validator.check_value_type('is_max_velocity', is_max_velocity, (int), self.name)
validator.check_value_type('max_velocity', max_velocity, (float), self.name)
"""Initialize MDIterationLeapFrogWithRF."""
validator.check_value_type('float4_numbers', float4_numbers, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
validator.check_value_type('half_dt', half_dt, float, self.name)
validator.check_value_type('dt', dt, float, self.name)
validator.check_value_type('exp_gamma', exp_gamma, float, self.name)
validator.check_value_type('is_max_velocity', is_max_velocity, int, self.name)
validator.check_value_type('max_velocity', max_velocity, float, self.name)
self.float4_numbers = float4_numbers
self.atom_numbers = atom_numbers
self.half_dt = half_dt
@ -2740,6 +2772,7 @@ class MDIterationLeapFrogLiujian(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, half_dt, dt, exp_gamma):
"""Initialize MDIterationLeapFrogLiujian."""
self.atom_numbers = atom_numbers
self.half_dt = half_dt
self.dt = dt
@ -2792,6 +2825,7 @@ class CrdToUintCrd(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers):
"""Initialize CrdToUintCrd."""
self.atom_numbers = atom_numbers
self.add_prim_attr('atom_numbers', self.atom_numbers)
self.init_prim_io_names(
@ -2827,6 +2861,7 @@ class MDIterationSetupRandState(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, atom_numbers, seed):
"""Initialize MDIterationSetupRandState."""
self.atom_numbers = atom_numbers
self.seed = seed
self.add_prim_attr('atom_numbers', self.atom_numbers)
@ -2873,10 +2908,11 @@ class TransferCrd(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, start_serial, end_serial, number, atom_numbers):
validator.check_value_type('start_serial', start_serial, (int), self.name)
validator.check_value_type('end_serial', end_serial, (int), self.name)
validator.check_value_type('number', number, (int), self.name)
validator.check_value_type('atom_numbers', atom_numbers, (int), self.name)
"""Initialize TransferCrd."""
validator.check_value_type('start_serial', start_serial, int, self.name)
validator.check_value_type('end_serial', end_serial, int, self.name)
validator.check_value_type('number', number, int, self.name)
validator.check_value_type('atom_numbers', atom_numbers, int, self.name)
self.start_serial = start_serial
self.end_serial = end_serial
self.number = number

View File

@ -29,6 +29,7 @@ from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr, PrimitiveWithInfer, prim_attr_register
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._utils.utils import generate_shape_index
from mindspore import Tensor, RowTensor, context
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
@ -89,17 +90,6 @@ class MindDataSet(MindData):
return tuple(lst)
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
ind_rank = len(indices_shape)
if axis < 0:
axis += out_rank - ind_rank + 1
perm_part1 = tuple(range(axis, axis + ind_rank))
index = tuple(range(out_rank))
perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
return perm
@constexpr
def _generate_inverse_index(x_shape, axis):
x_rank = len(x_shape)
@ -155,7 +145,7 @@ def get_bprop_sparse_gather_v2(self):
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)