forked from mindspore-Ecosystem/mindspore
!46764 rectify constexpr in vmap
Merge pull request !46764 from NaCN/constexpr_3
commit a23dc71bf0
@@ -128,7 +128,8 @@
"mindspore/tests/ut/python/nn/test_cell_method_attribute.py" "no-self-argument"
"mindspore/tests/ut/python/onnx/test_onnx.py" "unused-variable"
"mindspore/tests/ut/python/ops" "super-init-not-called"
"mindspore/tests/ut/python/ops/test_tensor_slice.py" "redefined-outer-name"
"mindspore/tests/st/ops/test_tensor_slice.py" "redefined-outer-name"
"mindspore/tests/st/ops/test_tensor_slice.py" "useless-super-delegation"
"mindspore/tests/ut/python/optimizer/test_debug_location.py" "super-init-not-called"
"mindspore/tests/ut/python/parallel/" "protected-access"
"mindspore/tests/ut/python/parameter_feature/test_var_grad.py" "bad-super-call"
@@ -16,7 +16,6 @@
"""array_ops vmap impl."""
from __future__ import absolute_import

import numpy as np
import mindspore
import mindspore.numpy as mnp
from mindspore import ops
@@ -147,14 +146,13 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
the generated prefix is a Tensor([[[0], [0]],
[[1], [1]]])
"""
if not indices_shape:
raise ValueError("indices_shape is empty in _get_prefix.")

indices_len = len(indices_shape)

if indices_len == 1:
prefix = np.arange(axis_size)
return Tensor(prefix, indices_dtype)
prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
return prefix

indices_end = indices_len - 1
prefix_shape = ()
@@ -169,8 +167,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
else:
expand_shape = expand_shape + (1,)

prefix = np.broadcast_to(np.arange(axis_size).reshape(expand_shape), prefix_shape)
return Tensor(prefix, indices_dtype)
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
return prefix

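For context, a minimal NumPy sketch (illustrative only, not part of the patch) that reproduces the prefix described in the _get_prefix docstring above, for axis_size=2 and indices of shape (2, 2, 1):

    import numpy as np

    axis_size = 2
    expand_shape = (axis_size, 1, 1)   # batch axis first, singleton elsewhere
    prefix_shape = (axis_size, 2, 1)   # matches the batched indices shape
    prefix = np.broadcast_to(np.arange(axis_size).reshape(expand_shape), prefix_shape)
    print(prefix)                      # value matches the docstring: [[[0], [0]], [[1], [1]]]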
@vmap_rules_getters.register(P.Transpose)
@@ -352,6 +351,7 @@ def get_unstack_vmap_rule(prim, axis_size):
def get_reshape_vmap_rule(prim, axis_size):
"""VmapRule for `Reshape` operation."""


@constexpr
def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
if x_dim == 0:
@@ -364,19 +364,20 @@ def get_reshape_vmap_rule(prim, axis_size):
dim_prod = 1
for i, shp_i in enumerate(target_shape):
if shp_i == -1:
if neg_index != -1:
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
neg_index = i
else:
dim_prod *= shp_i
arr_prod = np.prod(x_shape)
arr_prod = 1
for i in x_shape:
arr_prod *= i
target_shape_list = list(target_shape)
if neg_index != -1:
neg_index_size = int(arr_prod // (dim_prod * axis_size))
target_shape_list[neg_index] = neg_index_size

arr_prod_before_dim = np.prod(x_shape[:x_dim])

arr_prod_before_dim = 1
for i in x_shape[:x_dim]:
arr_prod_before_dim *= i
dim_prod = 1
for i, shp_i in enumerate(target_shape_list, start=1):
dim_prod *= shp_i
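A small self-contained sketch (hypothetical shapes, not part of the patch) of the -1 resolution performed above when vmap has prepended one batch axis of size axis_size:

    x_shape = (3, 4, 6)       # input shape including the batch dimension
    axis_size = 3             # size of the vmap batch axis
    target_shape = (4, -1)    # requested shape for a single batch element

    arr_prod = 1
    for s in x_shape:
        arr_prod *= s         # 72 elements in total
    dim_prod = 1
    for s in target_shape:
        if s != -1:
            dim_prod *= s     # product of the known target dims: 4
    neg_size = arr_prod // (dim_prod * axis_size)
    print(neg_size)           # 6: the -1 entry resolved per batch element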
@@ -565,12 +566,11 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
@constexpr
def _gen_indices_offset(shape, offset):
# original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
shape = [shape[0]] + [1] * (len(shape) - 2) + [shape[-1]]
val = np.zeros(shape, np.int32)  # the dtype will be changed when creating Tensor
val = np.reshape(val, (shape[0], shape[-1]))
shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
for i in range(shape[0]):
val[i, 0] = i * offset
return np.reshape(val, shape)
return P.Reshape()(val, shape)

if isinstance(prim, str):
prim = Primitive(prim)
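Illustrative NumPy sketch (hypothetical sizes, not part of the patch) of the per-batch offset that _gen_indices_offset builds so each batch element scatters into its own slice of the merged output:

    import numpy as np

    batch, last = 3, 2           # batched indices of shape (3, ..., 2)
    offset = 4                   # hypothetical size of the scattered leading dimension
    val = np.zeros((batch, last), np.int32)
    for i in range(batch):
        val[i, 0] = i * offset   # shift only the leading index coordinate
    print(val)                   # rows [0 0], [4 0], [8 0]; reshaped back to rank >= 3 afterwards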
@@ -591,7 +591,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
indices_shape = F.shape(indices)
indices_dtype = F.dtype(indices)
offset_val = _gen_indices_offset(indices_shape, offset)
indices_offset = Tensor(offset_val, indices_dtype)
indices_offset = P.Cast()(offset_val, indices_dtype)
new_indices = P.Add()(indices, indices_offset)
out = prim(new_indices, updates, new_shape)
real_out = P.Reshape()(out, out_shape)
@@ -135,9 +135,6 @@ def _get_broadcast_shape(x_shape, dst, axis_size):
x_ndim = len(x_shape)
broadcast_ndim = x_ndim + 1

if dst < -broadcast_ndim or dst >= broadcast_ndim:
_raise_value_error("Destination axis {} is out of bounds for array of dimension"
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
if dst < 0:
dst = broadcast_ndim + dst
@@ -16,7 +16,6 @@
"""convolution vmap impl"""
from __future__ import absolute_import

import numpy as np
import mindspore.numpy as mnp
from mindspore.ops import constexpr
from mindspore.ops import operations as P
@@ -174,10 +173,7 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
@constexpr
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
"""Get new shape for splitting src_dim into dst_size parts."""
dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
if remainder != 0:
_raise_value_error("The remainder of {} / {} should be 0, "
"but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
dst_size2 = shape[src_dim] // dst_size
new_shape = list(shape)
new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
return tuple(new_shape)
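A quick sketch (hypothetical shape, not part of the patch) of what _get_expand_shape computes when splitting one dimension into (dst_size, remaining) parts:

    shape, src_dim, dst_size = (8, 6, 5), 1, 3
    dst_size2 = shape[src_dim] // dst_size          # 6 // 3 = 2
    new_shape = list(shape)
    new_shape[src_dim:src_dim + 1] = [dst_size, dst_size2]
    print(tuple(new_shape))                         # (8, 3, 2, 5)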
@@ -71,41 +71,35 @@ complex_priority_map = {
complex_types = [mstype.complex64, mstype.complex128]


#TODO: remove comment
@constexpr
def raise_value_error(msg):
"""Constexpr for raise_value_error."""
raise ValueError(msg)


#TODO: remove comment
@constexpr
def raise_index_error(msg):
"""Constexpr for raise_index_error."""
raise IndexError(msg)


#TODO: remove comment
@constexpr
def raise_type_error(msg):
"""Constexpr for raise_type_error."""
raise TypeError(msg)


#TODO: remove comment
@constexpr
def raise_unimplemented_error(msg):
raise NotImplementedError(msg)


#TODO: remove comment
@constexpr
def log_warning(msg):
"""Adds warning to logger."""
logger.warning(msg)


#TODO: remove comment
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
@@ -114,14 +108,12 @@ def check_equal(param1, param2, msg="{},{}"):
return param1

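For readers unfamiliar with the decorator that recurs throughout these hunks: a constexpr helper is evaluated while the graph is being compiled rather than at run time, so it can use plain Python but only sees values that are constant at compile time. A minimal sketch (hypothetical helper, using the import shown earlier in this patch):

    from mindspore.ops import constexpr

    @constexpr
    def make_offsets(n):
        # plain Python, folded into the compiled graph as a constant tuple
        return tuple(i * 2 for i in range(n))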
#TODO: remove comment
@constexpr
def make_empty_slice():
"""Creates a empty slice."""
return slice(None, None, None)


#TODO: remove comment
@constexpr
def _deep_list(array_like, dim_size=None):
"""convert nested tuple/list mixtures to pure nested list"""
@@ -132,7 +124,6 @@ def _deep_list(array_like, dim_size=None):
return array_like


#TODO: remove comment
@constexpr
def deep_tuple(array_like):
"""convert nested tuple/list mixtures to pure nested tuple"""
@@ -165,8 +156,6 @@ def _deep_tensor_to_nparray(array_like):
return array_like


#TODO: remove comment
#@constexpr(run_graph=False)
@constexpr
def check_range(x, dim_size):
if dim_size is None:
@@ -178,7 +167,6 @@ def check_range(x, dim_size):
return x


#TODO: remove comment
@constexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
"""
@@ -225,7 +213,6 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
tensor_operator_registry.register('make_tensor', make_tensor)


#TODO: remove comment
@constexpr
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
"""Judges whether the data dim is valid."""
@@ -241,7 +228,6 @@ def get_source_shape(data_shape, value_shape):
return value_shape


#TODO: remove comment
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
@@ -268,7 +254,6 @@ def check_tensor_setitem_index(index, element_type=None):
"Index of type '{}' is not supported yet.".format(type(index)))


#TODO: remove comment
@constexpr
def is_same_type(inst, type_):
"""
@@ -284,7 +269,6 @@ def is_same_type(inst, type_):
return inst == type_


#TODO: remove comment
@constexpr
def check_valid_dim(dim, name):
"""Checks whether the dim is valid."""
@@ -292,7 +276,6 @@ def check_valid_dim(dim, name):
raise ValueError(f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")


#TODO: remove comment
@constexpr
def judge_index_type(index_type, target_type):
"""Judges whether the index type is valid."""
@@ -301,7 +284,6 @@ def judge_index_type(index_type, target_type):
return False


#TODO: remove comment
@constexpr
def judge_indexes_types(dtypes, target_type):
"""Check a tuple of tensor data type."""
@@ -315,7 +297,6 @@ def judge_indexes_types(dtypes, target_type):
return True


#TODO: remove comment
@constexpr
def check_type_isinstance(dtype, target_type):
"""Checks whether the dtype is instance of target type."""
@@ -324,14 +305,12 @@ def check_type_isinstance(dtype, target_type):
return isinstance(dtype, target_type)


#TODO: remove comment
@constexpr
def check_type_invalid(dtype, target_type):
"""Checks whether the dtype is valid."""
return dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type)


#TODO: remove comment
@constexpr
def check_type_valid(dtype, target_type, op_name):
"""Checks whether the dtype is valid."""
@@ -343,7 +322,6 @@ def check_type_valid(dtype, target_type, op_name):
f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")


#TODO: remove comment
@constexpr
def check_types_valid(dtypes, target_type, op_name):
"""Check a tuple of tensor data type."""
@@ -351,7 +329,6 @@ def check_types_valid(dtypes, target_type, op_name):
check_type_valid(dtype, target_type, op_name)


#TODO: remove comment
@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
@@ -403,7 +380,7 @@ def ellipsis2slice(input_, shape):
return tuple(result)


#remove constexpr
@constexpr
def slice2indices(input_slice, shape):
"""
Converts slice to indices.
@@ -435,7 +412,6 @@ def slice2indices(input_slice, shape):
return P.Stack(-1)(mesh_arrays)

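Illustrative NumPy sketch (hypothetical shape and slice, not the MindSpore code) of the index tensor that slice2indices above builds from a slice over the first axis:

    import numpy as np

    shape = (4, 2)
    s = slice(1, 4, 2)                        # selects rows 1 and 3
    rows = np.arange(*s.indices(shape[0]))    # [1 3]
    cols = np.arange(shape[1])                # [0 1]
    mesh = np.meshgrid(rows, cols, indexing='ij')
    print(np.stack(mesh, axis=-1))            # [[[1 0] [1 1]] [[3 0] [3 1]]]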
#TODO: remove comment
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
@@ -445,7 +421,6 @@ def check_indices(indices_size, index):
return indices_size


#TODO: remove comment
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
@@ -459,7 +434,6 @@ def check_indices_value_size(indices_size, value_size):
return value_size


#TODO: remove comment
@constexpr
def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
@@ -470,7 +444,6 @@ def tuple_index_type_cnt(types, op_name):
return MIXED


#TODO: remove comment
@constexpr
def check_value_elements(types):
"""Judges the type of all elements of the tuple."""
@@ -485,7 +458,6 @@ def check_value_elements(types):
return CONTAIN_TENSOR


#TODO: remove comment
@constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
@@ -497,7 +469,6 @@ def get_index_tensor_dtype(dtype):
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


#TODO: remove comment
@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
@@ -540,7 +511,7 @@ def get_broadcast_shape(x_shape, y_shape, prim_name):
return broadcast_shape


#remove constexpr
@constexpr
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
if not shapes:
@@ -551,20 +522,20 @@ def generate_broadcast_shape(shapes, op_name):
return tuple(broadcast_shape)


#remove constexpr
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check shape_y needs to be broadcast to shape_x."""
return shape_y != shape_x


#remove constexpr
@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between origin shape with broadcast shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], tuple(origin_shape)))

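Worked example (hypothetical shapes, not part of the patch) of the tile multiples that compute_multiples above returns:

    origin_shape, broadcast_shape = (3, 1), (2, 3, 4)
    len_gap = len(broadcast_shape) - len(origin_shape)    # 1
    multiples = broadcast_shape[:len_gap] + tuple(
        b // o for b, o in zip(broadcast_shape[len_gap:], origin_shape))
    print(multiples)                                      # (2, 1, 4)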
#remove constexpr
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -589,7 +560,6 @@ def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
return updates_shape


#TODO: remove comment
@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
"""Transforms slice to element list."""
@@ -601,7 +571,6 @@ def transform_slice_to_ele_list(slice_index, dim_len):
return slice_ele_list


#TODO: remove comment
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
slice_shapes, op_name, fancy_position=None):
@@ -631,14 +600,12 @@ def _judge_order_continuous(order_sequence):
return True


#TODO: remove comment
@constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
return x in y


#TODO: remove comment
@constexpr
def get_np_eps(input_dtype):
"""Get numpy eps."""
@@ -647,7 +614,6 @@ def get_np_eps(input_dtype):
return float(eps)


#TODO: remove comment
@constexpr
def check_number_index_type(number):
"""Check if it is int or bool number"""
@@ -659,7 +625,7 @@ def check_number_index_type(number):
.format(number, type(number)))


#remove constexpr
@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
"""Get stride info from a python slice"""
begin, end, step = get_slice_stride(slice_index, data_shape[0])
@@ -673,7 +639,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)

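Illustrative sketch (hypothetical data, and my own reading of the elided body, not the library code) of the kind of begin/end/step stride triples that get_stride_info_from_slice above returns: the slice drives the first axis and the remaining axes get full-range strides.

    data_shape = (5, 3)
    s = slice(1, 5, 2)
    begin, end, step = s.indices(data_shape[0])     # (1, 5, 2)
    begin_strides, end_strides, step_strides = [begin], [end], [step]
    for dim in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(dim)
        step_strides.append(1)
    print(tuple(begin_strides), tuple(end_strides), tuple(step_strides))
    # (1, 0) (5, 3) (2, 1)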
#remove constexpr
@constexpr
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from an integer"""
begin_strides = [number]
@@ -686,7 +652,6 @@ def get_stride_info_from_integer(data_shape, number):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


#TODO: remove comment
@constexpr
def get_slice_stride(index_slice, dim_size):
"""Get slice stride info"""
@@ -748,14 +713,12 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
return strides_v, shrink_axis


#TODO: remove comment
@constexpr
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
return Tensor(x)


#TODO: remove comment
@constexpr
def unpack(x):
if isinstance(x, (tuple, list)) and len(x) == 1:
@@ -763,7 +726,6 @@ def unpack(x):
return x


#TODO: remove comment
@constexpr
def normalize_start(start, dim_size):
"""
@@ -779,7 +741,6 @@ def normalize_start(start, dim_size):
return start if start < dim_size else dim_size


#TODO: remove comment
@constexpr
def normalize_stop(stop, dim_size):
"""
@@ -797,7 +758,6 @@ def normalize_stop(stop, dim_size):
return stop if stop < dim_size else dim_size


#TODO: remove comment
@constexpr
def get_step_from_slice(input_slice):
"""get step in a slice."""
@@ -807,7 +767,6 @@ def get_step_from_slice(input_slice):
return step


#TODO: remove comment
@constexpr
def normalize_slice(input_slice, dim_size):
"""Normalizes start, stop, step in a slice."""
@@ -823,7 +782,6 @@ def normalize_slice(input_slice, dim_size):
return start, stop, step


#TODO: remove comment
@constexpr
def tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
@@ -834,7 +792,6 @@ def expanded_shape(shape, expand_size):
return (1,)*expand_size + shape


#TODO: remove comment
@constexpr
def sequence_mul_int(seq, number):
"""
@@ -852,14 +809,12 @@ def sequence_mul_int(seq, number):
return seq * number


#TODO: remove comment
@constexpr
def check_in_sequence(x, y):
"""Determine whether the input `x` is in the sequence `y`."""
return x in y


#TODO: remove comment
@constexpr
def is_slice(x):
return isinstance(x, slice)
@@ -874,7 +829,6 @@ def filter_expanded_dims(shape, not_expanded_dim):
return tuple(res)


#TODO: remove comment
@constexpr
def sequence_to_index(sequence, dim_size):
"""Transforms sequence to tensor index."""
@@ -892,7 +846,7 @@ def sequence_to_index(sequence, dim_size):
return make_tensor(sequence, mstype.int64, None, dim_size)


#remove constexpr
@constexpr
def int_to_index(i, shape):
"""Converts integer to tensor indices."""
dim_size = shape[0]
@@ -913,7 +867,6 @@ def int_to_index(i, shape):
return P.Concat(-1)((P.Fill()(mstype.int64, P.Shape()(index)[:-1] + (1,), i), index))

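Illustrative NumPy sketch (hypothetical values, inferred from the return line above rather than the full library code) of the coordinate tensor int_to_index produces for an integer index into the first axis:

    import numpy as np

    i, shape = 1, (4, 2, 3)
    rest = np.stack(np.meshgrid(*[np.arange(d) for d in shape[1:]], indexing='ij'), axis=-1)
    full = np.concatenate([np.full(rest.shape[:-1] + (1,), i), rest], axis=-1)
    print(full.shape)   # (2, 3, 3): each entry is a coordinate [i, j, k]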
#TODO: remove comment
@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
"""Adds remaining dimensions not indexed to not_expanded_dim"""
@@ -933,13 +886,11 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim
return not_expanded_dim, idx_advanced


#TODO: remove comment
@constexpr
def check_slice_empty(start, stop, step):
return (start - stop)*step >= 0


#TODO: remove comment
@constexpr
def real_axes(ndim_orig, ndim_out, axes_orig):
"""Returns the real axes to be reduced after performing broadcast"""
@@ -952,7 +903,7 @@ def real_axes(ndim_orig, ndim_out, axes_orig):
check_axis_valid_const = constexpr(validator.check_axis_valid)


#remove constexpr
@constexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
"""Computes slice tensor shapes"""
shape = [1] * len(slice_shape)
@@ -961,7 +912,7 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
return shape


#remove constexpr
@constexpr
def infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
@@ -977,7 +928,6 @@ def infer_out_shape(*shapes):
return tuple(shape_out)

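A one-line check (illustrative only, independent of the patch) of the broadcast result that infer_out_shape above computes:

    import numpy as np

    print(np.broadcast_shapes((2, 1, 4), (3, 1), (4,)))   # (2, 3, 4)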
#TODO: remove comment
@constexpr
def use_copy_slice(tuple_index):
if tuple_index is not None and len(tuple_index) >= 2:
@@ -987,7 +937,6 @@ def use_copy_slice(tuple_index):
return False


#TODO: remove comment
@constexpr
def is_ascend():
return context.get_context('device_target') == "Ascend"
@@ -998,7 +947,6 @@ def gen_exception_msg(msg_format, *args):
return msg_format.format(*args)


#TODO: remove comment
@constexpr
def get_output_dtype(dtype_1, dtype_2, use_complex=False):
"""Returns output dtype after type promotion."""
@@ -1017,7 +965,6 @@ def get_output_dtype(dtype_1, dtype_2, use_complex=False):
return dtype_2


#TODO: remove comment
@constexpr
def promote_binary_dtype(dtype_1, dtype_2):
if dtype_1 == dtype_2:
@@ -4412,7 +4412,6 @@ def index_fill(x, axis, index, value):
return index_fill_(x, axis, index, value)


#TODO: remove comment
@constexpr
def _check_check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
@@ -19,12 +19,13 @@ import mindspore.numpy as mnp
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.context as context
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ...mindspore_test_framework.mindspore_test import mindspore_test
from ...mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE)


class MeshGrid(Cell):
def construct(self, a, b, c, d):
ret = mnp.meshgrid(a, b, c, d)
@@ -20,8 +20,8 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ...mindspore_test_framework.mindspore_test import mindspore_test
from ...mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception