forked from mindspore-Ecosystem/mindspore
!11205 master_tensor_getitem_debug
From: @yepei6 Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian
This commit is contained in:
commit
0018070b54
|
@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type):
|
|||
@constexpr
|
||||
def check_type_valid(dtype, target_type, op_name):
|
||||
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
|
||||
raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
||||
raise TypeError(
|
||||
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_index_type_valid(dtype, target_type, op_name):
|
||||
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
|
||||
raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
||||
raise IndexError(
|
||||
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
|
|||
raise IndexError(f"For '{op_name}', the index elements only support "
|
||||
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
|
||||
if len(ellipsis_positions) > 1:
|
||||
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
|
||||
raise IndexError(
|
||||
f"For '{op_name}, an index can only have a single ellipsis('...')")
|
||||
|
||||
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
|
||||
tensor_positions, sequence_positions
|
||||
|
@ -341,7 +344,8 @@ def tuple_index_int_cnt(types, op_name):
|
|||
def tuple_index_type_cnt(types, op_name):
|
||||
"""count the tensor type of types which contains the tuple elements' type."""
|
||||
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
|
||||
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
|
||||
basic_cnt = sum(isinstance(
|
||||
ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
|
||||
if tensor_cnt == len(types):
|
||||
return ALL_TENSOR
|
||||
if basic_cnt == len(types):
|
||||
|
@ -380,7 +384,7 @@ def check_value_elements(data_dtype, types):
|
|||
@constexpr
|
||||
def get_index_tensor_dtype(dtype):
|
||||
"""Check a tuple of tensor data type."""
|
||||
if dtype == mstype.int32:
|
||||
if dtype in mstype.int_type:
|
||||
return INT_
|
||||
if dtype == mstype.bool_:
|
||||
return BOOL_
|
||||
|
@ -428,7 +432,8 @@ def generate_broadcast_shape(shapes, op_name):
|
|||
for i, shape in enumerate(shapes):
|
||||
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
|
||||
try:
|
||||
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
|
||||
broadcast_shape = op_utils.get_broadcast_shape(
|
||||
broadcast_shape, shape, op_name)
|
||||
except ValueError as ex:
|
||||
raise IndexError(ex)
|
||||
return tuple(broadcast_shape)
|
||||
|
@ -473,9 +478,10 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
|
|||
@constexpr
|
||||
def check_sequence_index_type(sequence_index, op_name):
|
||||
"""check if the item's type of list_index is bool or int"""
|
||||
if not all([isinstance(index, (int, bool)) for index in sequence_index]):
|
||||
raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
|
||||
f"but got {type(index)} in array")
|
||||
for index in sequence_index:
|
||||
if not isinstance(index, int):
|
||||
raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), "
|
||||
f"but got {type(index)} in array.")
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -497,13 +503,17 @@ def check_and_transform_int_index(index, shape, op_name):
|
|||
@constexpr
|
||||
def transform_sequence_index(sequence_index, shape, op_name):
|
||||
"""transform list or tuple with integer and boolean to tuple with integer index"""
|
||||
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
|
||||
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
|
||||
bool_count = len(
|
||||
list(filter(lambda index: isinstance(index, bool), sequence_index)))
|
||||
int_count = len(
|
||||
list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
|
||||
if int_count == 0:
|
||||
if bool_count == shape:
|
||||
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
|
||||
list_index = list(
|
||||
filter(lambda i: sequence_index[i], range(bool_count)))
|
||||
else:
|
||||
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
|
||||
raise IndexError(
|
||||
"The boolean array should have the same length with the corresponding dimensiton")
|
||||
else:
|
||||
list_index = [int(index) for index in sequence_index]
|
||||
for i, index in enumerate(list_index):
|
||||
|
|
Loading…
Reference in New Issue