forked from mindspore-Ecosystem/mindspore
!47529 delete constexpr note
Merge pull request !47529 from NaCN/constexpr_remove
This commit is contained in:
commit
7ccb517428
|
@ -52,7 +52,7 @@ def _cal_repeat_dims(x_rank, rep, expand_axis):
|
|||
return tuple(rep_dims)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _cal_reshape(x_shape, rep, axis):
|
||||
x_reshape = list(x_shape)
|
||||
x_reshape[axis] *= rep
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""math Operations."""
|
||||
import numpy as np
|
||||
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
@ -24,7 +23,6 @@ from mindspore.ops.function.math_func import cummin as cummin_
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_validate_axis(axis, name):
|
||||
if isinstance(axis, (tuple, list)):
|
||||
|
@ -34,14 +32,12 @@ def _check_validate_axis(axis, name):
|
|||
return axis
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_validate_keepdims(keep_dims, name):
|
||||
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
|
||||
return keep_dims
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def is_const(x):
|
||||
return x is not None
|
||||
|
@ -119,7 +115,6 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|||
return nonzero_num
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _int_to_tuple_conv(axes):
|
||||
"""
|
||||
|
@ -131,7 +126,6 @@ def _int_to_tuple_conv(axes):
|
|||
return axes
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_axes(axes, prim_name=None):
|
||||
"""
|
||||
|
@ -152,7 +146,6 @@ def _check_axes(axes, prim_name=None):
|
|||
return axes
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _typecheck_input(x1_type, x2_type, prim_name=None):
|
||||
"""
|
||||
|
@ -166,7 +159,7 @@ def _typecheck_input(x1_type, x2_type, prim_name=None):
|
|||
f"and x2_type: {x2_type}.")
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
||||
"""
|
||||
Convert from single int axes to 2d tuple if required
|
||||
|
@ -182,7 +175,6 @@ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|||
return axes
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _calc_new_shape(shape, axes, position=0):
|
||||
"""
|
||||
|
@ -190,10 +182,14 @@ def _calc_new_shape(shape, axes, position=0):
|
|||
'position' refers to whether tensor is first or second in the op.
|
||||
"""
|
||||
contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
|
||||
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
|
||||
prod_contraction = 1
|
||||
for i in contraction_axes:
|
||||
prod_contraction *= shape[i]
|
||||
free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
|
||||
free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
|
||||
prod_free = int(np.prod(free_dims))
|
||||
prod_free = 1
|
||||
for free_dim in free_dims:
|
||||
prod_free *= free_dim
|
||||
|
||||
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
||||
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
|
||||
|
@ -272,7 +268,6 @@ def tensor_dot(x1, x2, axes):
|
|||
return final_result
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
|
||||
"""
|
||||
|
@ -286,7 +281,7 @@ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
|
|||
f"x1_type: {x1_type} and x2_type: {x2_type}.")
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _get_transpose_shape(x2_shape):
|
||||
x2_shape_range = tuple(range(len(x2_shape)))
|
||||
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
||||
|
@ -385,7 +380,7 @@ def dot(x1, x2):
|
|||
return matmul_op(x1, x2)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
|
||||
"""
|
||||
Get batch sizes from two inputs
|
||||
|
@ -393,7 +388,6 @@ def _get_batch_size(x1_shape, x2_shape, prim_name=None):
|
|||
return x1_shape[0], x2_shape[0]
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
|
||||
"""
|
||||
|
@ -407,7 +401,7 @@ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
|
|||
f"x2_type: {x2_type}.")
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
||||
"""
|
||||
Check whether axes are valid and cast axes from tuple to list
|
||||
|
@ -442,10 +436,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|||
"""
|
||||
axis = axes[position]
|
||||
contraction_axes = tuple([axis])
|
||||
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
|
||||
prod_contraction = 1
|
||||
for i in contraction_axes:
|
||||
prod_contraction *= shape[i]
|
||||
free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
|
||||
free_dims = tuple(shape[i] for i in free_axes)
|
||||
prod_free = int(np.prod(free_dims))
|
||||
prod_free = 1
|
||||
for free_dim in free_dims:
|
||||
prod_free *= free_dim
|
||||
|
||||
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
||||
transpose_perm = tuple([0]) + transpose_perm
|
||||
|
@ -454,7 +452,6 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|||
return new_shape, transpose_perm, free_dims
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
||||
"""
|
||||
|
@ -466,7 +463,6 @@ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
|||
f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _get_output_shape(batch_size, x1_ret, x2_ret):
|
||||
"""
|
||||
|
|
|
@ -114,7 +114,6 @@ def get_x_shape(x_shape):
|
|||
return (s,)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_attr_dtype(param_name, input_dtype, allow_dtypes, cls_name):
|
||||
validator.check_value_type(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
@ -1488,7 +1487,6 @@ def flatten(input_x):
|
|||
return _flatten(input_x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
|
||||
if isinstance(scalar, int) and tensor_type != mstype.int32:
|
||||
|
@ -1499,7 +1497,6 @@ def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
|
|||
f"then the input[{tensor_name}] must be a Tensor of float32.")
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
|
||||
if not is_cond_tensor:
|
||||
|
@ -5517,7 +5514,6 @@ def expand(input_x, size):
|
|||
return expand_op(input_x, size)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_fold_param(param, param_name):
|
||||
"""Check the parameters of fold op."""
|
||||
|
@ -5579,7 +5575,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
|||
return fold_op(input, output_size)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_unfold_params(param, param_name, param_size):
|
||||
"""Check the parameters of unfold op."""
|
||||
|
@ -5646,7 +5641,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
|||
return unfold_op(input)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_diagonal_axes(dim1, dim2, x_ndim):
|
||||
"""Check the parameters of unfold op."""
|
||||
|
|
|
@ -76,7 +76,6 @@ from mindspore._c_expression import Tensor as Tensor_
|
|||
import mindspore.ops.function as F
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _make_tensor(val, dtype):
|
||||
"""Returns the tensor with value `val` and dtype `dtype`."""
|
||||
|
@ -5421,7 +5420,7 @@ def frac(x):
|
|||
#####################################
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _create_cummin_perm(axis, x_shape):
|
||||
"""Insure axis is in [-len(x_shape),len(s_shape)-1]"""
|
||||
len_axis = len(x_shape)
|
||||
|
@ -6571,7 +6570,6 @@ def hann_window(window_length, periodic=True):
|
|||
return Tensor(w[:-1]) if periodic else Tensor(w)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _type_convert(force, obj):
|
||||
"""
|
||||
|
@ -7353,7 +7351,7 @@ def _min(*args):
|
|||
return min(*args)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
|
||||
"""Infers the shape of the last two dimensions after performing matmul."""
|
||||
shape_rem = []
|
||||
|
@ -7368,7 +7366,7 @@ def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
|
|||
return tuple(shape_rem)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _check_matmul_shapes(shape1, shape2, prim_name=None):
|
||||
"""Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
|
||||
shape_out = list()
|
||||
|
@ -7383,7 +7381,7 @@ def _check_matmul_shapes(shape1, shape2, prim_name=None):
|
|||
return tuple(shape_out)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _tile_size(shape, out_shape, ndim):
|
||||
"""Returns tile_size such that shape*tile_size = out_shape"""
|
||||
size = [1] * ndim
|
||||
|
@ -7393,7 +7391,7 @@ def _tile_size(shape, out_shape, ndim):
|
|||
return tuple(size)
|
||||
|
||||
|
||||
#remove constexpr
|
||||
@constexpr
|
||||
def _expand(x, ndim):
|
||||
"""Expand x to ndim from axis, which can be 0 or -1."""
|
||||
rank_op = _get_cache_prim(P.Rank)()
|
||||
|
|
|
@ -203,7 +203,6 @@ def adaptive_avg_pool3d(input_x, output_size):
|
|||
return adaptive_avg_pool3d_(input_x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avgpool_1d_type_and_int(kernel_size, stride, ceil_mode, count_include_pad):
|
||||
"""Checks the type of avgpool1d input"""
|
||||
|
@ -215,7 +214,6 @@ def _check_avgpool_1d_type_and_int(kernel_size, stride, ceil_mode, count_include
|
|||
validator.check_int(stride, 1, Rel.GE, "stride", 'avg_pool1d')
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
|
||||
"""Check argument is non-negative integer, which mean arg_value >= 0."""
|
||||
|
@ -304,7 +302,6 @@ def avg_pool1d(input_x, kernel_size=1, stride=1, padding=0, ceil_mode=False, cou
|
|||
return input_x
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avgpool_2d_kernel_size(kernel_size):
|
||||
"""check and calculate the avgpool2d kernel_size"""
|
||||
|
@ -322,7 +319,6 @@ def _check_avgpool_2d_kernel_size(kernel_size):
|
|||
return kernel_size
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avgpool_2d_stride(stride):
|
||||
"""check and calculate the avgpool2d stride"""
|
||||
|
@ -340,7 +336,6 @@ def _check_avgpool_2d_stride(stride):
|
|||
return stride
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avgpool_2d_padding(padding):
|
||||
"""check and calculate the avgpool2d padding"""
|
||||
|
@ -358,7 +353,6 @@ def _check_avgpool_2d_padding(padding):
|
|||
return padding
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avg_pool2d_type_and_value(ceil_mode, count_include_pad, divisor_override):
|
||||
"""check the type of avgpool2d input"""
|
||||
|
@ -453,7 +447,6 @@ def avg_pool2d(input_x, kernel_size=1, stride=1, padding=0, ceil_mode=False, cou
|
|||
return input_x
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_avg_pool3d_padding(padding):
|
||||
"""Check the padding value in avg_pool3d op."""
|
||||
|
@ -1333,7 +1326,6 @@ def fast_gelu(x):
|
|||
return fast_gelu_(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_float_range_inc_right(arg_value, lower_limit, upper_limit, arg_name=None, prim_name=None):
|
||||
"""
|
||||
|
@ -1656,7 +1648,6 @@ def hardshrink(x, lambd=0.5):
|
|||
return hshrink_op(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
|
@ -1667,7 +1658,6 @@ def _check_axis_in_range(axis, ndim):
|
|||
return axis % ndim
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_axis_valid(axes, ndim):
|
||||
"""
|
||||
|
@ -1694,7 +1684,6 @@ def _get_flip_end(ndim, shape, axes):
|
|||
return tuple([-shape[i] - 1 if i in axes else shape[i] + 1 for i in range(ndim)])
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _get_flip_strides(ndim, axes):
|
||||
"""Calculate the strides of flip"""
|
||||
|
@ -1892,7 +1881,6 @@ def hardswish(x):
|
|||
return hardswish_(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
|
||||
prim_name):
|
||||
|
@ -2459,7 +2447,6 @@ def pdist(x, p=2.0):
|
|||
return pdist_(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_pad_inputs(padding):
|
||||
"""check the input of pad"""
|
||||
|
@ -3432,14 +3419,12 @@ def mish(x):
|
|||
return mish_(x)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_value_type(arg_name, arg_value, valid_types, prim_name=None):
|
||||
"""Checks whether a value is instance of some types."""
|
||||
return validator.check_value_type(arg_name, arg_value, valid_types, prim_name)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr(check=False)
|
||||
def _check_is_tensor(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
|
@ -3448,7 +3433,6 @@ def _check_is_tensor(param_name, input_data, cls_name):
|
|||
f"but got '{ops.typeof(input_data)}'")
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_number_gt_value(arg_name, arg_value, value, cls_name):
|
||||
"""Internal function, used to judge whether arg_value is greater than or equal to value."""
|
||||
|
@ -3646,7 +3630,6 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
|
|||
return _grid_sampler_3d(input_x, grid)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_ctc_loss_inputs(blank, reduction, zero_infinity, prim_name):
|
||||
validator.check_value_type("blank", blank, [int], prim_name)
|
||||
|
@ -3723,7 +3706,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
|
|||
return (loss, log_alpha)
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_gaussian_nll_loss(full, eps, reduction):
|
||||
validator.check_value_type('full', full, [bool], 'gaussian_nll_loss')
|
||||
|
@ -4366,7 +4348,6 @@ def adaptive_avg_pool1d(input_x, output_size):
|
|||
return input_x
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_adaptive_max_pool1d_output_size(output_size):
|
||||
"""Check the output_size value in adaptive_max_pool1d op."""
|
||||
|
@ -4739,7 +4720,6 @@ def conv3d(inputs, weight, pad_mode="valid", padding=0, stride=1, dilation=1, gr
|
|||
return output
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr
|
||||
def _check_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||
validator.check_positive_int(arg_value, arg_name=arg_name, prim_name=prim_name)
|
||||
|
|
|
@ -73,7 +73,6 @@ def random_gamma(shape, alpha, seed=None):
|
|||
return output
|
||||
|
||||
|
||||
#TODO: remove comment
|
||||
@constexpr(reuse_result=False)
|
||||
def _get_seed(op_seed, kernel_name):
|
||||
"""Get the graph-level seed."""
|
||||
|
|
Loading…
Reference in New Issue