!46764 rectify constexpr in vmap

Merge pull request !46764 from NaCN/constexpr_3
This commit is contained in:
i-robot 2023-01-18 07:24:52 +00:00 committed by Gitee
commit a23dc71bf0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 35 additions and 94 deletions

View File

@ -128,7 +128,8 @@
"mindspore/tests/ut/python/nn/test_cell_method_attribute.py" "no-self-argument"
"mindspore/tests/ut/python/onnx/test_onnx.py" "unused-variable"
"mindspore/tests/ut/python/ops" "super-init-not-called"
"mindspore/tests/ut/python/ops/test_tensor_slice.py" "redefined-outer-name"
"mindspore/tests/st/ops/test_tensor_slice.py" "redefined-outer-name"
"mindspore/tests/st/ops/test_tensor_slice.py" "useless-super-delegation"
"mindspore/tests/ut/python/optimizer/test_debug_location.py" "super-init-not-called"
"mindspore/tests/ut/python/parallel/" "protected-access"
"mindspore/tests/ut/python/parameter_feature/test_var_grad.py" "bad-super-call"

View File

@ -16,7 +16,6 @@
"""array_ops vmap impl."""
from __future__ import absolute_import
import numpy as np
import mindspore
import mindspore.numpy as mnp
from mindspore import ops
@ -147,14 +146,13 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
the generated prefix is a Tensor([[[0], [0]],
[[1], [1]]])
"""
if not indices_shape:
raise ValueError("indices_shape is empty in _get_prefix.")
indices_len = len(indices_shape)
if indices_len == 1:
prefix = np.arange(axis_size)
return Tensor(prefix, indices_dtype)
prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
return prefix
indices_end = indices_len - 1
prefix_shape = ()
@ -169,8 +167,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
else:
expand_shape = expand_shape + (1,)
prefix = np.broadcast_to(np.arange(axis_size).reshape(expand_shape), prefix_shape)
return Tensor(prefix, indices_dtype)
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
return prefix
@vmap_rules_getters.register(P.Transpose)
@ -352,6 +351,7 @@ def get_unstack_vmap_rule(prim, axis_size):
def get_reshape_vmap_rule(prim, axis_size):
"""VmapRule for `Reshape` operation."""
@constexpr
def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
if x_dim == 0:
@ -364,19 +364,20 @@ def get_reshape_vmap_rule(prim, axis_size):
dim_prod = 1
for i, shp_i in enumerate(target_shape):
if shp_i == -1:
if neg_index != -1:
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
neg_index = i
else:
dim_prod *= shp_i
arr_prod = np.prod(x_shape)
arr_prod = 1
for i in x_shape:
arr_prod *= i
target_shape_list = list(target_shape)
if neg_index != -1:
neg_index_size = int(arr_prod // (dim_prod * axis_size))
target_shape_list[neg_index] = neg_index_size
arr_prod_before_dim = np.prod(x_shape[:x_dim])
arr_prod_before_dim = 1
for i in x_shape[:x_dim]:
arr_prod_before_dim *= i
dim_prod = 1
for i, shp_i in enumerate(target_shape_list, start=1):
dim_prod *= shp_i
@ -565,12 +566,11 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
@constexpr
def _gen_indices_offset(shape, offset):
# original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
shape = [shape[0]] + [1] * (len(shape) - 2) + [shape[-1]]
val = np.zeros(shape, np.int32) # the dtype will be changed when creating Tensor
val = np.reshape(val, (shape[0], shape[-1]))
shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
for i in range(shape[0]):
val[i, 0] = i * offset
return np.reshape(val, shape)
return P.Reshape()(val, shape)
if isinstance(prim, str):
prim = Primitive(prim)
@ -591,7 +591,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
indices_shape = F.shape(indices)
indices_dtype = F.dtype(indices)
offset_val = _gen_indices_offset(indices_shape, offset)
indices_offset = Tensor(offset_val, indices_dtype)
indices_offset = P.Cast()(offset_val, indices_dtype)
new_indices = P.Add()(indices, indices_offset)
out = prim(new_indices, updates, new_shape)
real_out = P.Reshape()(out, out_shape)

View File

@ -135,9 +135,6 @@ def _get_broadcast_shape(x_shape, dst, axis_size):
x_ndim = len(x_shape)
broadcast_ndim = x_ndim + 1
if dst < -broadcast_ndim or dst >= broadcast_ndim:
_raise_value_error("Destination axis {} is out of bounds for array of dimension"
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
if dst < 0:
dst = broadcast_ndim + dst

View File

@ -16,7 +16,6 @@
"""convolution vmap impl"""
from __future__ import absolute_import
import numpy as np
import mindspore.numpy as mnp
from mindspore.ops import constexpr
from mindspore.ops import operations as P
@ -174,10 +173,7 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
@constexpr
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
"""Get new shape for splitting src_dim into dst_size parts."""
dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
if remainder != 0:
_raise_value_error("The remainder of {} / {} should be 0, "
"but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
dst_size2 = shape[src_dim] // dst_size
new_shape = list(shape)
new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
return tuple(new_shape)

View File

@ -71,41 +71,35 @@ complex_priority_map = {
complex_types = [mstype.complex64, mstype.complex128]
#TODO: remove comment
@constexpr
def raise_value_error(msg):
"""Constexpr for raise_value_error."""
raise ValueError(msg)
#TODO: remove comment
@constexpr
def raise_index_error(msg):
"""Constexpr for raise_index_error."""
raise IndexError(msg)
#TODO: remove comment
@constexpr
def raise_type_error(msg):
"""Constexpr for raise_type_error."""
raise TypeError(msg)
#TODO: remove comment
@constexpr
def raise_unimplemented_error(msg):
raise NotImplementedError(msg)
#TODO: remove comment
@constexpr
def log_warning(msg):
"""Adds warning to logger."""
logger.warning(msg)
#TODO: remove comment
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
@ -114,14 +108,12 @@ def check_equal(param1, param2, msg="{},{}"):
return param1
#TODO: remove comment
@constexpr
def make_empty_slice():
"""Creates a empty slice."""
return slice(None, None, None)
#TODO: remove comment
@constexpr
def _deep_list(array_like, dim_size=None):
"""convert nested tuple/list mixtures to pure nested list"""
@ -132,7 +124,6 @@ def _deep_list(array_like, dim_size=None):
return array_like
#TODO: remove comment
@constexpr
def deep_tuple(array_like):
"""convert nested tuple/list mixtures to pure nested tuple"""
@ -165,8 +156,6 @@ def _deep_tensor_to_nparray(array_like):
return array_like
#TODO: remove comment
#@constexpr(run_graph=False)
@constexpr
def check_range(x, dim_size):
if dim_size is None:
@ -178,7 +167,6 @@ def check_range(x, dim_size):
return x
#TODO: remove comment
@constexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
"""
@ -225,7 +213,6 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
tensor_operator_registry.register('make_tensor', make_tensor)
#TODO: remove comment
@constexpr
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
"""Judges whether the data dim is valid."""
@ -241,7 +228,6 @@ def get_source_shape(data_shape, value_shape):
return value_shape
#TODO: remove comment
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
@ -268,7 +254,6 @@ def check_tensor_setitem_index(index, element_type=None):
"Index of type '{}' is not supported yet.".format(type(index)))
#TODO: remove comment
@constexpr
def is_same_type(inst, type_):
"""
@ -284,7 +269,6 @@ def is_same_type(inst, type_):
return inst == type_
#TODO: remove comment
@constexpr
def check_valid_dim(dim, name):
"""Checks whether the dim is valid."""
@ -292,7 +276,6 @@ def check_valid_dim(dim, name):
raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")
#TODO: remove comment
@constexpr
def judge_index_type(index_type, target_type):
"""Judges whether the index type is valid."""
@ -301,7 +284,6 @@ def judge_index_type(index_type, target_type):
return False
#TODO: remove comment
@constexpr
def judge_indexes_types(dtypes, target_type):
"""Check a tuple of tensor data type."""
@ -315,7 +297,6 @@ def judge_indexes_types(dtypes, target_type):
return True
#TODO: remove comment
@constexpr
def check_type_isinstance(dtype, target_type):
"""Checks whether the dtype is instance of target type."""
@ -324,14 +305,12 @@ def check_type_isinstance(dtype, target_type):
return isinstance(dtype, target_type)
#TODO: remove comment
@constexpr
def check_type_invalid(dtype, target_type):
"""Checks whether the dtype is valid."""
return dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type)
#TODO: remove comment
@constexpr
def check_type_valid(dtype, target_type, op_name):
"""Checks whether the dtype is valid."""
@ -343,7 +322,6 @@ def check_type_valid(dtype, target_type, op_name):
f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")
#TODO: remove comment
@constexpr
def check_types_valid(dtypes, target_type, op_name):
"""Check a tuple of tensor data type."""
@ -351,7 +329,6 @@ def check_types_valid(dtypes, target_type, op_name):
check_type_valid(dtype, target_type, op_name)
#TODO: remove comment
@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
@ -403,7 +380,7 @@ def ellipsis2slice(input_, shape):
return tuple(result)
#remove constexpr
@constexpr
def slice2indices(input_slice, shape):
"""
Converts slice to indices.
@ -435,7 +412,6 @@ def slice2indices(input_slice, shape):
return P.Stack(-1)(mesh_arrays)
#TODO: remove comment
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
@ -445,7 +421,6 @@ def check_indices(indices_size, index):
return indices_size
#TODO: remove comment
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
@ -459,7 +434,6 @@ def check_indices_value_size(indices_size, value_size):
return value_size
#TODO: remove comment
@constexpr
def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
@ -470,7 +444,6 @@ def tuple_index_type_cnt(types, op_name):
return MIXED
#TODO: remove comment
@constexpr
def check_value_elements(types):
"""Judges the type of all elements of the tuple."""
@ -485,7 +458,6 @@ def check_value_elements(types):
return CONTAIN_TENSOR
#TODO: remove comment
@constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
@ -497,7 +469,6 @@ def get_index_tensor_dtype(dtype):
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
#TODO: remove comment
@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
@ -540,7 +511,7 @@ def get_broadcast_shape(x_shape, y_shape, prim_name):
return broadcast_shape
#remove constexpr
@constexpr
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
if not shapes:
@ -551,20 +522,20 @@ def generate_broadcast_shape(shapes, op_name):
return tuple(broadcast_shape)
#remove constexpr
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check shape_y needs to be broadcast to shape_x."""
return shape_y != shape_x
#remove constexpr
@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between origin shape with broadcast shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], tuple(origin_shape)))
#remove constexpr
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
@ -589,7 +560,6 @@ def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
return updates_shape
#TODO: remove comment
@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
"""Transforms slice to element list."""
@ -601,7 +571,6 @@ def transform_slice_to_ele_list(slice_index, dim_len):
return slice_ele_list
#TODO: remove comment
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
slice_shapes, op_name, fancy_position=None):
@ -631,14 +600,12 @@ def _judge_order_continuous(order_sequence):
return True
#TODO: remove comment
@constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
return x in y
#TODO: remove comment
@constexpr
def get_np_eps(input_dtype):
"""Get numpy eps."""
@ -647,7 +614,6 @@ def get_np_eps(input_dtype):
return float(eps)
#TODO: remove comment
@constexpr
def check_number_index_type(number):
"""Check if it is int or bool number"""
@ -659,7 +625,7 @@ def check_number_index_type(number):
.format(number, type(number)))
#remove constexpr
@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
"""Get stride info from a python slice"""
begin, end, step = get_slice_stride(slice_index, data_shape[0])
@ -673,7 +639,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
#remove constexpr
@constexpr
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from an integer"""
begin_strides = [number]
@ -686,7 +652,6 @@ def get_stride_info_from_integer(data_shape, number):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
#TODO: remove comment
@constexpr
def get_slice_stride(index_slice, dim_size):
"""Get slice stride info"""
@ -748,14 +713,12 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
return strides_v, shrink_axis
#TODO: remove comment
@constexpr
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
return Tensor(x)
#TODO: remove comment
@constexpr
def unpack(x):
if isinstance(x, (tuple, list)) and len(x) == 1:
@ -763,7 +726,6 @@ def unpack(x):
return x
#TODO: remove comment
@constexpr
def normalize_start(start, dim_size):
"""
@ -779,7 +741,6 @@ def normalize_start(start, dim_size):
return start if start < dim_size else dim_size
#TODO: remove comment
@constexpr
def normalize_stop(stop, dim_size):
"""
@ -797,7 +758,6 @@ def normalize_stop(stop, dim_size):
return stop if stop < dim_size else dim_size
#TODO: remove comment
@constexpr
def get_step_from_slice(input_slice):
"""get step in a slice."""
@ -807,7 +767,6 @@ def get_step_from_slice(input_slice):
return step
#TODO: remove comment
@constexpr
def normalize_slice(input_slice, dim_size):
"""Normalizes start, stop, step in a slice."""
@ -823,7 +782,6 @@ def normalize_slice(input_slice, dim_size):
return start, stop, step
#TODO: remove comment
@constexpr
def tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
@ -834,7 +792,6 @@ def expanded_shape(shape, expand_size):
return (1,)*expand_size + shape
#TODO: remove comment
@constexpr
def sequence_mul_int(seq, number):
"""
@ -852,14 +809,12 @@ def sequence_mul_int(seq, number):
return seq * number
#TODO: remove comment
@constexpr
def check_in_sequence(x, y):
"""Determine whether the input `x` is in the sequence `y`."""
return x in y
#TODO: remove comment
@constexpr
def is_slice(x):
return isinstance(x, slice)
@ -874,7 +829,6 @@ def filter_expanded_dims(shape, not_expanded_dim):
return tuple(res)
#TODO: remove comment
@constexpr
def sequence_to_index(sequence, dim_size):
"""Transforms sequence to tensor index."""
@ -892,7 +846,7 @@ def sequence_to_index(sequence, dim_size):
return make_tensor(sequence, mstype.int64, None, dim_size)
#remove constexpr
@constexpr
def int_to_index(i, shape):
"""Converts integer to tensor indices."""
dim_size = shape[0]
@ -913,7 +867,6 @@ def int_to_index(i, shape):
return P.Concat(-1)((P.Fill()(mstype.int64, P.Shape()(index)[:-1] + (1,), i), index))
#TODO: remove comment
@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
"""Adds remaining dimensions not indexed to not_expanded_dim"""
@ -933,13 +886,11 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim
return not_expanded_dim, idx_advanced
#TODO: remove comment
@constexpr
def check_slice_empty(start, stop, step):
return (start - stop)*step >= 0
#TODO: remove comment
@constexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
"""Returns the real axes to be reduced after performing broadcast"""
@ -952,7 +903,7 @@ def real_axes(ndim_orig, ndim_out, axes_orig):
check_axis_valid_const = constexpr(validator.check_axis_valid)
#remove constexpr
@constexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
"""Computes slice tensor shapes"""
shape = [1] * len(slice_shape)
@ -961,7 +912,7 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
return shape
#remove constexpr
@constexpr
def infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
@ -977,7 +928,6 @@ def infer_out_shape(*shapes):
return tuple(shape_out)
#TODO: remove comment
@constexpr
def use_copy_slice(tuple_index):
if tuple_index is not None and len(tuple_index) >= 2:
@ -987,7 +937,6 @@ def use_copy_slice(tuple_index):
return False
#TODO: remove comment
@constexpr
def is_ascend():
return context.get_context('device_target') == "Ascend"
@ -998,7 +947,6 @@ def gen_exception_msg(msg_format, *args):
return msg_format.format(*args)
#TODO: remove comment
@constexpr
def get_output_dtype(dtype_1, dtype_2, use_complex=False):
"""Returns output dtype after type promotion."""
@ -1017,7 +965,6 @@ def get_output_dtype(dtype_1, dtype_2, use_complex=False):
return dtype_2
#TODO: remove comment
@constexpr
def promote_binary_dtype(dtype_1, dtype_2):
if dtype_1 == dtype_2:

View File

@ -4412,7 +4412,6 @@ def index_fill(x, axis, index, value):
return index_fill_(x, axis, index, value)
#TODO: remove comment
@constexpr
def _check_check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""

View File

@ -19,12 +19,13 @@ import mindspore.numpy as mnp
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.context as context
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ...mindspore_test_framework.mindspore_test import mindspore_test
from ...mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)
class MeshGrid(Cell):
def construct(self, a, b, c, d):
ret = mnp.meshgrid(a, b, c, d)

View File

@ -20,8 +20,8 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ...mindspore_test_framework.mindspore_test import mindspore_test
from ...mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception