forked from mindspore-Ecosystem/mindspore
!11205 master_tensor_getitem_debug
From: @yepei6 Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian
This commit is contained in:
commit 0018070b54
@@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type):
 @constexpr
 def check_type_valid(dtype, target_type, op_name):
     if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
-        raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
+        raise TypeError(
+            f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")


 @constexpr
 def check_index_type_valid(dtype, target_type, op_name):
     if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
-        raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
+        raise IndexError(
+            f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")


 @constexpr
@@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
             raise IndexError(f"For '{op_name}', the index elements only support "
                              f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
     if len(ellipsis_positions) > 1:
-        raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
+        raise IndexError(
+            f"For '{op_name}, an index can only have a single ellipsis('...')")

     return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
         tensor_positions, sequence_positions
@@ -260,7 +263,7 @@ def ellipsis2slice(input_, shape):
     return tuple(result)


-@ constexpr
+@constexpr
 def slice2indices(input_slices, shape):
     """
     Converts slice to indices.
@@ -285,7 +288,7 @@ def slice2indices(input_slices, shape):
     return ravel


-@ constexpr
+@constexpr
 def check_indices(indices_size, index):
     """Checks indices whether is empty."""
     if indices_size < 1:
@@ -294,7 +297,7 @@ def check_indices(indices_size, index):
     return indices_size


-@ constexpr
+@constexpr
 def check_indices_value_size(indices_size, value_size):
     """Checks if the sizes are already matched."""
     if value_size < 1:
@@ -307,7 +310,7 @@ def check_indices_value_size(indices_size, value_size):
     return value_size


-@ constexpr
+@constexpr
 def integer_to_indices(index, shape):
     """Converts int or tuple[int] to indices."""
     size = reduce(lambda x, y: x * y, shape)
@@ -317,7 +320,7 @@ def integer_to_indices(index, shape):
     return Tensor(value, dtype=mstype.int32)


-@ constexpr
+@constexpr
 def tuple_element_is_int(indexs):
     """Judges tuple element type."""
     if not indexs:
@@ -330,18 +333,19 @@ def tuple_element_is_int(indexs):
     return False


-@ constexpr
+@constexpr
 def tuple_index_int_cnt(types, op_name):
     """count the int type of types which contains the tuple elements' type."""
     int_cnt = sum(isinstance(ele, mstype.Int) for ele in types)
     return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT


-@ constexpr
+@constexpr
 def tuple_index_type_cnt(types, op_name):
     """count the tensor type of types which contains the tuple elements' type."""
     tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
-    basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
+    basic_cnt = sum(isinstance(
+        ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
     if tensor_cnt == len(types):
         return ALL_TENSOR
     if basic_cnt == len(types):
@@ -349,7 +353,7 @@ def tuple_index_type_cnt(types, op_name):
     return MIXED


-@ constexpr
+@constexpr
 def check_value_elements(data_dtype, types):
     """Judges the type of all elements of the tuple."""
     tensors_number = 0
@@ -377,10 +381,10 @@ def check_value_elements(data_dtype, types):
 # TODO to del


-@ constexpr
+@constexpr
 def get_index_tensor_dtype(dtype):
     """Check a tuple of tensor data type."""
-    if dtype == mstype.int32:
+    if dtype in mstype.int_type:
         return INT_
     if dtype == mstype.bool_:
         return BOOL_
@@ -389,7 +393,7 @@ def get_index_tensor_dtype(dtype):


 # TODO to del
-@ constexpr
+@constexpr
 def check_index_tensors_dtype(indexes_types, op_name):
     """Check a tuple of tensor data type."""
     for index_type in indexes_types:
@@ -400,7 +404,7 @@ def check_index_tensors_dtype(indexes_types, op_name):


 # TODO to del
-@ constexpr
+@constexpr
 def check_index_tensor_dtype(index_type, op_name):
     """Check a tensor data type."""
     if index_type in (mstype.int32, mstype.int64):
@@ -410,7 +414,7 @@ def check_index_tensor_dtype(index_type, op_name):


 # TODO to del
-@ constexpr
+@constexpr
 def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
     """Check tensors data type same."""
     if value_dtype == data_dtype:
@@ -419,7 +423,7 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
                         f"is not consistent with assigned tensor data type {data_dtype}.")


-@ constexpr
+@constexpr
 def generate_broadcast_shape(shapes, op_name):
     """Generate broadcast shape for a tuple of shape."""
     if not shapes:
@@ -428,13 +432,14 @@ def generate_broadcast_shape(shapes, op_name):
     for i, shape in enumerate(shapes):
         logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
         try:
-            broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
+            broadcast_shape = op_utils.get_broadcast_shape(
+                broadcast_shape, shape, op_name)
         except ValueError as ex:
             raise IndexError(ex)
     return tuple(broadcast_shape)


-@ constexpr
+@constexpr
 def check_two_shapes_need_broadcast(shape_x, shape_y):
     """Check two shapes need broadcast."""
     error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
@@ -451,14 +456,14 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):
     return True


-@ constexpr
+@constexpr
 def compute_multiples(origin_shape, broadcast_shape):
     """Compute multiples between origin shape with broadcast shape."""
     len_gap = len(broadcast_shape) - len(origin_shape)
     return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


-@ constexpr
+@constexpr
 def compute_new_shape(origin_shape, indexes_shapes_info):
     """Compute new shape between origin shape with final shape."""
     new_shape = []
@@ -470,21 +475,22 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
     return tuple(new_shape)


-@ constexpr
+@constexpr
 def check_sequence_index_type(sequence_index, op_name):
     """check if the item's type of list_index is bool or int"""
-    if not all([isinstance(index, (int, bool)) for index in sequence_index]):
-        raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
-                         f"but got {type(index)} in array")
+    for index in sequence_index:
+        if not isinstance(index, int):
+            raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), "
+                             f"but got {type(index)} in array.")


-@ constexpr
+@constexpr
 def convert_int_to_slice(tuple_index):
     tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
     return tuple_index_new


-@ constexpr
+@constexpr
 def check_and_transform_int_index(index, shape, op_name):
     if index < -shape or index >= shape:
         raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
@@ -494,16 +500,20 @@ def check_and_transform_int_index(index, shape, op_name):
     return index


-@ constexpr
+@constexpr
 def transform_sequence_index(sequence_index, shape, op_name):
     """transform list or tuple with integer and boolean to tuple with integer index"""
-    bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
-    int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
+    bool_count = len(
+        list(filter(lambda index: isinstance(index, bool), sequence_index)))
+    int_count = len(
+        list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
     if int_count == 0:
         if bool_count == shape:
-            list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
+            list_index = list(
+                filter(lambda i: sequence_index[i], range(bool_count)))
         else:
-            raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
+            raise IndexError(
+                "The boolean array should have the same length with the corresponding dimensiton")
     else:
         list_index = [int(index) for index in sequence_index]
     for i, index in enumerate(list_index):
@@ -512,7 +522,7 @@ def transform_sequence_index(sequence_index, shape, op_name):
     return sub_tuple_index


-@ constexpr
+@constexpr
 def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
     """Convert a slice to a tensor."""
     shape = []
@@ -540,7 +550,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n
     return tensor


-@ constexpr
+@constexpr
 def check_shapes_same(value_shapes, op_name):
     """Check if the shapes in the tuple are consistent."""
     for i, shape in enumerate(value_shapes):
@@ -550,7 +560,7 @@ def check_shapes_same(value_shapes, op_name):
     return True


-@ constexpr
+@constexpr
 def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
     """Convert a scalar to a tensor."""
     if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -563,7 +573,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
                         f" is not consistent with the assigned tensor data type {data_dtype}.")


-@ constexpr
+@constexpr
 def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
     """Convert a tuple of scalar to a tensor."""
     updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
@@ -575,7 +585,7 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
     return Tensor(np.tile(array, reps))


-@ constexpr
+@constexpr
 def generate_updates_shape(data_shape, index_shape, op_type):
     """Generate updates shape for 'tensor setitem'."""
     if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -585,7 +595,7 @@ def generate_updates_shape(data_shape, index_shape, op_type):
     return updates_shape


-@ constexpr
+@constexpr
 def check_tuple_index_len(data_rank, tuple_index_len, op_name):
     """Check if the number of index tensor exceeds the dimension of the operated tensor."""
     if tuple_index_len <= data_rank:
@@ -594,7 +604,7 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name):
                      f"is greater than the dimension {data_rank} of the operated tensor.")


-@ constexpr
+@constexpr
 def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
                                                     tensor_indexes_dtypes, slice_indexes, op_name):
     """
@@ -694,14 +704,14 @@ def scalar_in_sequence(x, y):
     return False


-@ constexpr
+@constexpr
 def get_np_eps(input_dtype):
     nptype = mstype.dtype_to_nptype(input_dtype)
     eps = np.finfo(nptype).eps
     return float(eps)


-@ constexpr
+@constexpr
 def check_number_index_type(number):
     """Check if it is int or bool number"""
     if isinstance(number, bool):
@@ -712,7 +722,7 @@ def check_number_index_type(number):
                     .format(number, type(number)))


-@ constexpr
+@constexpr
 def get_stride_info_from_slice(data_shape, slice_index):
     """Get stride info from a python slice"""
     begin, end, step = get_slice_stride(data_shape[0], slice_index)
@@ -726,7 +736,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
     return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


-@ constexpr
+@constexpr
 def get_stride_info_from_integer(data_shape, number):
     """Get stride info from a integer"""
     begin_strides = [number]
@@ -752,7 +762,7 @@ def get_slice_stride(dim_size, index_slice):
     return start, stop, step


-@ constexpr
+@constexpr
 def get_stride_info_from_tuple(data_shape, tuple_index):
     """Get stride info from a tuple"""
     begin_strides, end_strides, step_strides = [], [], []
@@ -792,14 +802,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
     return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis


-@ constexpr
+@constexpr
 def mstype_eq(x, y):
     if x == y:
         return True
     return False


-@ constexpr
+@constexpr
 def scalar_to_tensor(x):
     """Convert a scalar to a tensor"""
     return Tensor(x)
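
A quick illustration (not part of this commit): the index-normalization helpers reformatted above are compile-time `@constexpr` functions inside MindSpore, but their logic can be sketched in plain Python. The sketch below restates what `transform_sequence_index` and `convert_int_to_slice` do under simplifying assumptions; the `*_sketch` names and sample data are invented for illustration and are not the library API.

def transform_sequence_index_sketch(sequence_index, dim_size):
    """Normalize a boolean or integer list/tuple index to a tuple of integer positions."""
    bool_count = sum(isinstance(i, bool) for i in sequence_index)
    if bool_count == len(sequence_index):
        # A boolean mask must be as long as the dimension it indexes.
        if bool_count != dim_size:
            raise IndexError("The boolean array should have the same length as the corresponding dimension")
        return tuple(pos for pos, flag in enumerate(sequence_index) if flag)
    # Otherwise every element is treated as an integer position.
    return tuple(int(i) for i in sequence_index)

def convert_int_to_slice_sketch(tuple_index):
    """Replace each integer index with the equivalent length-1 slice, as convert_int_to_slice does."""
    return tuple(slice(i, i + 1, 1) for i in tuple_index)

# A boolean mask over a dimension of size 4 selects positions 0 and 2;
# integer indices (1, 3) become the slices 1:2:1 and 3:4:1.
assert transform_sequence_index_sketch([True, False, True, False], 4) == (0, 2)
assert convert_int_to_slice_sketch((1, 3)) == (slice(1, 2, 1), slice(3, 4, 1))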
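
Similarly, `compute_multiples` in the hunk at -451,14 derives the per-axis tile counts needed to expand a tensor from its own shape to a broadcast shape. A minimal plain-Python sketch of the same arithmetic, with a hypothetical name and invented example shapes:

def compute_multiples_sketch(origin_shape, broadcast_shape):
    """How many times each axis must be tiled so origin_shape fills broadcast_shape."""
    len_gap = len(broadcast_shape) - len(origin_shape)
    # New leading axes are tiled in full; aligned axes are tiled by the size ratio.
    return broadcast_shape[:len_gap] + tuple(
        b // o for b, o in zip(broadcast_shape[len_gap:], origin_shape))

# A (1, 3) tensor broadcast to (2, 4, 3): tile 2x on the new leading axis,
# 4x on the first original axis, and 1x (unchanged) on the last axis.
assert compute_multiples_sketch((1, 3), (2, 4, 3)) == (2, 4, 1)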