forked from mindspore-Ecosystem/mindspore
!11316 master getitem debug
From: @yepei6 Reviewed-by: @zh_qh,@kingxian Signed-off-by: @kingxian
This commit is contained in:
commit
fbcfbed66d
|
@ -129,43 +129,24 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
|
||||||
return tuple_index_new
|
return tuple_index_new
|
||||||
|
|
||||||
|
|
||||||
def _expand_data_dims_with_none(data, tuple_index, op_name):
|
def _expand_data_dims(data, tuple_index, op_name):
|
||||||
"""expand the data's dim with 'None' in tuple_index"""
|
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||||
none_positions, tuple_index_without_none = (), ()
|
expand_positions, tuple_index_new = (), ()
|
||||||
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
||||||
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
|
if const_utils.judge_index_type(index_type, mstype.type_none):
|
||||||
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,)
|
tuple_index_new += (const_utils.make_empty_slice(),)
|
||||||
none_positions += (i,) if none_type_tag else ()
|
expand_positions += (i,)
|
||||||
for dim in none_positions:
|
elif const_utils.judge_index_type(index_type, mstype.bool_):
|
||||||
data = F.expand_dims(data, dim)
|
tuple_index_new += (const_utils.make_tensor([0] if index else[], mstype.int64),)
|
||||||
return data, tuple_index_without_none
|
expand_positions += (i,)
|
||||||
|
|
||||||
|
|
||||||
def _expand_data_dims_with_bool(data, tuple_index, op_name):
|
|
||||||
"""expand the data's dim with 'True/False' in tuple_index"""
|
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
||||||
bool_positions, tuple_index_without_bool = (), ()
|
|
||||||
|
|
||||||
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
|
||||||
bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_)
|
|
||||||
if bool_type_tag:
|
|
||||||
if index:
|
|
||||||
tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),)
|
|
||||||
else:
|
|
||||||
# todo wait to complete the operations' support for zero dim-size, then could make 0 length tensor.
|
|
||||||
# to replace the 'False'
|
|
||||||
|
|
||||||
return const_utils.raise_index_error("When tensor is indexed by a tuple which contains bool object, "
|
|
||||||
"the value only support 'True'.")
|
|
||||||
else:
|
else:
|
||||||
tuple_index_without_bool += (index,)
|
tuple_index_new += (index,)
|
||||||
bool_positions += (i,) if bool_type_tag else ()
|
|
||||||
|
|
||||||
for dim in bool_positions:
|
for dim in expand_positions:
|
||||||
data = F.expand_dims(data, dim)
|
data = F.expand_dims(data, dim)
|
||||||
|
|
||||||
return data, tuple_index_without_bool
|
return data, tuple_index_new
|
||||||
|
|
||||||
|
|
||||||
def tensor_index_by_slice(data, slice_index):
|
def tensor_index_by_slice(data, slice_index):
|
||||||
|
@ -219,22 +200,22 @@ def tensor_index_by_tensor(data, tensor_index):
|
||||||
def tensor_index_by_list(data, list_index):
|
def tensor_index_by_list(data, list_index):
|
||||||
"""Tensor getitem by list of int and bool"""
|
"""Tensor getitem by list of int and bool"""
|
||||||
data_shape = F.shape(data)
|
data_shape = F.shape(data)
|
||||||
const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM)
|
indexes_types = hyper_map(F.typeof, list_index)
|
||||||
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
|
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
|
||||||
tensor_index = F.tuple_to_array(sub_tuple_index)
|
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
|
||||||
tensor_index = F.cast(tensor_index, mstype.int64)
|
tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
|
||||||
return F.gather(data, tensor_index, 0)
|
return F.gather(data, tensor_index, 0)
|
||||||
|
tuple_index_new = ()
|
||||||
|
for index in list_index:
|
||||||
|
tuple_index_new += (index,)
|
||||||
|
return tensor_index_by_tuple(data, tuple_index_new)
|
||||||
|
|
||||||
|
|
||||||
def tensor_index_by_tuple(data, tuple_index):
|
def tensor_index_by_tuple(data, tuple_index):
|
||||||
"""Tensor getitem by tuple of various types with None"""
|
"""Tensor getitem by tuple of various types with None"""
|
||||||
op_name = const_utils.TENSOR_GETITEM
|
op_name = const_utils.TENSOR_GETITEM
|
||||||
if len(tuple_index) == 1:
|
|
||||||
return data[tuple_index[0]]
|
|
||||||
|
|
||||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||||
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
|
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||||
data, tuple_index = _expand_data_dims_with_bool(data, tuple_index, op_name)
|
|
||||||
|
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
||||||
|
@ -502,7 +483,7 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
||||||
return data
|
return data
|
||||||
op_name = const_utils.TENSOR_GETITEM
|
op_name = const_utils.TENSOR_GETITEM
|
||||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||||
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
|
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||||
|
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||||
|
@ -564,7 +545,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
||||||
return data
|
return data
|
||||||
op_name = const_utils.TENSOR_GETITEM
|
op_name = const_utils.TENSOR_GETITEM
|
||||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||||
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
|
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||||
|
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||||
|
@ -592,7 +573,7 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
||||||
return data
|
return data
|
||||||
op_name = const_utils.TENSOR_GETITEM
|
op_name = const_utils.TENSOR_GETITEM
|
||||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||||
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
|
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||||
|
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||||
|
|
|
@ -74,7 +74,7 @@ def make_empty_slice():
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def make_tensor(data, data_type, data_shape=None):
|
def make_tensor(data, data_type=mstype.int64, data_shape=None):
|
||||||
if data_shape:
|
if data_shape:
|
||||||
return Tensor(np.zeros(data_shape), data_type)
|
return Tensor(np.zeros(data_shape), data_type)
|
||||||
return Tensor(data, data_type)
|
return Tensor(data, data_type)
|
||||||
|
@ -158,6 +158,15 @@ def check_index_type_valid(dtype, target_type, op_name):
|
||||||
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def judge_indexes_types(dtypes, target_type):
|
||||||
|
"""Check a tuple of tensor data type."""
|
||||||
|
for dtype in dtypes:
|
||||||
|
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_indexes_types_valid(dtypes, target_type, op_name):
|
def check_indexes_types_valid(dtypes, target_type, op_name):
|
||||||
"""Check a tuple of tensor data type."""
|
"""Check a tuple of tensor data type."""
|
||||||
|
@ -475,15 +484,6 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
|
||||||
return tuple(new_shape)
|
return tuple(new_shape)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_sequence_index_type(sequence_index, op_name):
|
|
||||||
"""check if the item's type of list_index is bool or int"""
|
|
||||||
for index in sequence_index:
|
|
||||||
if not isinstance(index, int):
|
|
||||||
raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), "
|
|
||||||
f"but got {type(index)} in array.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def convert_int_to_slice(tuple_index):
|
def convert_int_to_slice(tuple_index):
|
||||||
tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
|
tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
|
||||||
|
@ -503,19 +503,16 @@ def check_and_transform_int_index(index, shape, op_name):
|
||||||
@constexpr
|
@constexpr
|
||||||
def transform_sequence_index(sequence_index, shape, op_name):
|
def transform_sequence_index(sequence_index, shape, op_name):
|
||||||
"""transform list or tuple with integer and boolean to tuple with integer index"""
|
"""transform list or tuple with integer and boolean to tuple with integer index"""
|
||||||
bool_count = len(
|
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
|
||||||
list(filter(lambda index: isinstance(index, bool), sequence_index)))
|
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
|
||||||
int_count = len(
|
if int_count == 0 and bool_count != 0:
|
||||||
list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
|
|
||||||
if int_count == 0:
|
|
||||||
if bool_count == shape:
|
if bool_count == shape:
|
||||||
list_index = list(
|
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
|
||||||
filter(lambda i: sequence_index[i], range(bool_count)))
|
|
||||||
else:
|
else:
|
||||||
raise IndexError(
|
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
|
||||||
"The boolean array should have the same length with the corresponding dimensiton")
|
|
||||||
else:
|
else:
|
||||||
list_index = [int(index) for index in sequence_index]
|
list_index = [int(index) for index in sequence_index]
|
||||||
|
|
||||||
for i, index in enumerate(list_index):
|
for i, index in enumerate(list_index):
|
||||||
list_index[i] = check_and_transform_int_index(index, shape, op_name)
|
list_index[i] = check_and_transform_int_index(index, shape, op_name)
|
||||||
sub_tuple_index = tuple(list_index)
|
sub_tuple_index = tuple(list_index)
|
||||||
|
|
Loading…
Reference in New Issue