!47529 delete constexpr note

Merge pull request !47529 from NaCN/constexpr_remove
i-robot 2023-01-05 11:46:29 +00:00 committed by Gitee
commit 7ccb517428
6 changed files with 22 additions and 55 deletions

View File

@@ -52,7 +52,7 @@ def _cal_repeat_dims(x_rank, rep, expand_axis):
return tuple(rep_dims)
#remove constexpr
@constexpr
def _cal_reshape(x_shape, rep, axis):
x_reshape = list(x_shape)
x_reshape[axis] *= rep

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""math Operations."""
import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
@@ -24,7 +23,6 @@ from mindspore.ops.function.math_func import cummin as cummin_
from mindspore.ops import operations as P
#TODO: remove comment
@constexpr
def _check_validate_axis(axis, name):
if isinstance(axis, (tuple, list)):
@@ -34,14 +32,12 @@ def _check_validate_axis(axis, name):
return axis
#TODO: remove comment
@constexpr
def _check_validate_keepdims(keep_dims, name):
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
return keep_dims
#TODO: remove comment
@constexpr
def is_const(x):
return x is not None
@@ -119,7 +115,6 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
return nonzero_num
#TODO: remove comment
@constexpr
def _int_to_tuple_conv(axes):
"""
@@ -131,7 +126,6 @@ def _int_to_tuple_conv(axes):
return axes
#TODO: remove comment
@constexpr
def _check_axes(axes, prim_name=None):
"""
@@ -152,7 +146,6 @@ def _check_axes(axes, prim_name=None):
return axes
#TODO: remove comment
@constexpr
def _typecheck_input(x1_type, x2_type, prim_name=None):
"""
@@ -166,7 +159,7 @@ def _typecheck_input(x1_type, x2_type, prim_name=None):
f"and x2_type: {x2_type}.")
#remove constexpr
@constexpr
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
"""
Convert from single int axes to 2d tuple if required
@@ -182,7 +175,6 @@ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
return axes
#TODO: remove comment
@constexpr
def _calc_new_shape(shape, axes, position=0):
"""
@@ -190,10 +182,14 @@ def _calc_new_shape(shape, axes, position=0):
     'position' refers to whether tensor is first or second in the op.
     """
     contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
-    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
+    prod_contraction = 1
+    for i in contraction_axes:
+        prod_contraction *= shape[i]
     free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
     free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
-    prod_free = int(np.prod(free_dims))
+    prod_free = 1
+    for free_dim in free_dims:
+        prod_free *= free_dim
     transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
     new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
@@ -272,7 +268,6 @@ def tensor_dot(x1, x2, axes):
return final_result
#TODO: remove comment
@constexpr
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
"""
@@ -286,7 +281,7 @@ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
f"x1_type: {x1_type} and x2_type: {x2_type}.")
#remove constexpr
@constexpr
def _get_transpose_shape(x2_shape):
x2_shape_range = tuple(range(len(x2_shape)))
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
@@ -385,7 +380,7 @@ def dot(x1, x2):
return matmul_op(x1, x2)
#remove constexpr
@constexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
"""
Get batch sizes from two inputs
@@ -393,7 +388,6 @@ def _get_batch_size(x1_shape, x2_shape, prim_name=None):
return x1_shape[0], x2_shape[0]
#TODO: remove comment
@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
"""
@@ -407,7 +401,7 @@ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
f"x2_type: {x2_type}.")
#remove constexpr
@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
"""
Check whether axes are valid and cast axes from tuple to list
@@ -442,10 +436,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
     """
     axis = axes[position]
     contraction_axes = tuple([axis])
-    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
+    prod_contraction = 1
+    for i in contraction_axes:
+        prod_contraction *= shape[i]
     free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
     free_dims = tuple(shape[i] for i in free_axes)
-    prod_free = int(np.prod(free_dims))
+    prod_free = 1
+    for free_dim in free_dims:
+        prod_free *= free_dim
     transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
     transpose_perm = tuple([0]) + transpose_perm
@@ -454,7 +452,6 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
return new_shape, transpose_perm, free_dims
#TODO: remove comment
@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
"""
@@ -466,7 +463,6 @@ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
#TODO: remove comment
@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
"""

View File

@@ -114,7 +114,6 @@ def get_x_shape(x_shape):
return (s,)
#TODO: remove comment
@constexpr
def _check_attr_dtype(param_name, input_dtype, allow_dtypes, cls_name):
validator.check_value_type(param_name, input_dtype, allow_dtypes, cls_name)
@@ -1488,7 +1487,6 @@ def flatten(input_x):
return _flatten(input_x)
#TODO: remove comment
@constexpr
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
if isinstance(scalar, int) and tensor_type != mstype.int32:
@@ -1499,7 +1497,6 @@ def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
f"then the input[{tensor_name}] must be a Tensor of float32.")
#TODO: remove comment
@constexpr
def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
if not is_cond_tensor:
@@ -5517,7 +5514,6 @@ def expand(input_x, size):
return expand_op(input_x, size)
#TODO: remove comment
@constexpr
def _check_fold_param(param, param_name):
"""Check the parameters of fold op."""
@@ -5579,7 +5575,6 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
return fold_op(input, output_size)
#TODO: remove comment
@constexpr
def _check_unfold_params(param, param_name, param_size):
"""Check the parameters of unfold op."""
@@ -5646,7 +5641,6 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
return unfold_op(input)
#TODO: remove comment
@constexpr
def _check_diagonal_axes(dim1, dim2, x_ndim):
"""Check the parameters of unfold op."""

View File

@@ -76,7 +76,6 @@ from mindspore._c_expression import Tensor as Tensor_
import mindspore.ops.function as F
#TODO: remove comment
@constexpr
def _make_tensor(val, dtype):
"""Returns the tensor with value `val` and dtype `dtype`."""
@@ -5421,7 +5420,7 @@ def frac(x):
#####################################
#remove constexpr
@constexpr
def _create_cummin_perm(axis, x_shape):
"""Insure axis is in [-len(x_shape),len(s_shape)-1]"""
len_axis = len(x_shape)
@@ -6571,7 +6570,6 @@ def hann_window(window_length, periodic=True):
return Tensor(w[:-1]) if periodic else Tensor(w)
#TODO: remove comment
@constexpr
def _type_convert(force, obj):
"""
@@ -7353,7 +7351,7 @@ def _min(*args):
return min(*args)
#remove constexpr
@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
"""Infers the shape of the last two dimensions after performing matmul."""
shape_rem = []
@@ -7368,7 +7366,7 @@ def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
return tuple(shape_rem)
#remove constexpr
@constexpr
def _check_matmul_shapes(shape1, shape2, prim_name=None):
"""Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
shape_out = list()
@@ -7383,7 +7381,7 @@ def _check_matmul_shapes(shape1, shape2, prim_name=None):
return tuple(shape_out)
#remove constexpr
@constexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
size = [1] * ndim
@@ -7393,7 +7391,7 @@ def _tile_size(shape, out_shape, ndim):
return tuple(size)
#remove constexpr
@constexpr
def _expand(x, ndim):
"""Expand x to ndim from axis, which can be 0 or -1."""
rank_op = _get_cache_prim(P.Rank)()

View File

@@ -203,7 +203,6 @@ def adaptive_avg_pool3d(input_x, output_size):
return adaptive_avg_pool3d_(input_x)
#TODO: remove comment
@constexpr
def _check_avgpool_1d_type_and_int(kernel_size, stride, ceil_mode, count_include_pad):
"""Checks the type of avgpool1d input"""
@@ -215,7 +214,6 @@ def _check_avgpool_1d_type_and_int(kernel_size, stride, ceil_mode, count_include
validator.check_int(stride, 1, Rel.GE, "stride", 'avg_pool1d')
#TODO: remove comment
@constexpr
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
"""Check argument is non-negative integer, which mean arg_value >= 0."""
@@ -304,7 +302,6 @@ def avg_pool1d(input_x, kernel_size=1, stride=1, padding=0, ceil_mode=False, cou
return input_x
#TODO: remove comment
@constexpr
def _check_avgpool_2d_kernel_size(kernel_size):
"""check and calculate the avgpool2d kernel_size"""
@@ -322,7 +319,6 @@ def _check_avgpool_2d_kernel_size(kernel_size):
return kernel_size
#TODO: remove comment
@constexpr
def _check_avgpool_2d_stride(stride):
"""check and calculate the avgpool2d stride"""
@@ -340,7 +336,6 @@ def _check_avgpool_2d_stride(stride):
return stride
#TODO: remove comment
@constexpr
def _check_avgpool_2d_padding(padding):
"""check and calculate the avgpool2d padding"""
@@ -358,7 +353,6 @@ def _check_avgpool_2d_padding(padding):
return padding
#TODO: remove comment
@constexpr
def _check_avg_pool2d_type_and_value(ceil_mode, count_include_pad, divisor_override):
"""check the type of avgpool2d input"""
@@ -453,7 +447,6 @@ def avg_pool2d(input_x, kernel_size=1, stride=1, padding=0, ceil_mode=False, cou
return input_x
#TODO: remove comment
@constexpr
def _check_avg_pool3d_padding(padding):
"""Check the padding value in avg_pool3d op."""
@@ -1333,7 +1326,6 @@ def fast_gelu(x):
return fast_gelu_(x)
#TODO: remove comment
@constexpr
def _check_float_range_inc_right(arg_value, lower_limit, upper_limit, arg_name=None, prim_name=None):
"""
@@ -1656,7 +1648,6 @@ def hardshrink(x, lambd=0.5):
return hshrink_op(x)
#TODO: remove comment
@constexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
@@ -1667,7 +1658,6 @@ def _check_axis_in_range(axis, ndim):
return axis % ndim
#TODO: remove comment
@constexpr
def _check_axis_valid(axes, ndim):
"""
@@ -1694,7 +1684,6 @@ def _get_flip_end(ndim, shape, axes):
return tuple([-shape[i] - 1 if i in axes else shape[i] + 1 for i in range(ndim)])
#TODO: remove comment
@constexpr
def _get_flip_strides(ndim, axes):
"""Calculate the strides of flip"""
@@ -1892,7 +1881,6 @@ def hardswish(x):
return hardswish_(x)
#TODO: remove comment
@constexpr
def _check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
prim_name):
@@ -2459,7 +2447,6 @@ def pdist(x, p=2.0):
return pdist_(x)
#TODO: remove comment
@constexpr
def _check_pad_inputs(padding):
"""check the input of pad"""
@@ -3432,14 +3419,12 @@ def mish(x):
return mish_(x)
#TODO: remove comment
@constexpr
def _check_value_type(arg_name, arg_value, valid_types, prim_name=None):
"""Checks whether a value is instance of some types."""
return validator.check_value_type(arg_name, arg_value, valid_types, prim_name)
#TODO: remove comment
@constexpr(check=False)
def _check_is_tensor(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor."""
@@ -3448,7 +3433,6 @@ def _check_is_tensor(param_name, input_data, cls_name):
f"but got '{ops.typeof(input_data)}'")
#TODO: remove comment
@constexpr
def _check_number_gt_value(arg_name, arg_value, value, cls_name):
"""Internal function, used to judge whether arg_value is greater than or equal to value."""
@@ -3646,7 +3630,6 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
return _grid_sampler_3d(input_x, grid)
#TODO: remove comment
@constexpr
def _check_ctc_loss_inputs(blank, reduction, zero_infinity, prim_name):
validator.check_value_type("blank", blank, [int], prim_name)
@@ -3723,7 +3706,6 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
return (loss, log_alpha)
#TODO: remove comment
@constexpr
def _check_gaussian_nll_loss(full, eps, reduction):
validator.check_value_type('full', full, [bool], 'gaussian_nll_loss')
@@ -4366,7 +4348,6 @@ def adaptive_avg_pool1d(input_x, output_size):
return input_x
#TODO: remove comment
@constexpr
def _check_adaptive_max_pool1d_output_size(output_size):
"""Check the output_size value in adaptive_max_pool1d op."""
@@ -4739,7 +4720,6 @@ def conv3d(inputs, weight, pad_mode="valid", padding=0, stride=1, dilation=1, gr
return output
#TODO: remove comment
@constexpr
def _check_positive_int(arg_value, arg_name=None, prim_name=None):
validator.check_positive_int(arg_value, arg_name=arg_name, prim_name=prim_name)

View File

@@ -73,7 +73,6 @@ def random_gamma(shape, alpha, seed=None):
return output
#TODO: remove comment
@constexpr(reuse_result=False)
def _get_seed(op_seed, kernel_name):
"""Get the graph-level seed."""