forked from mindspore-Ecosystem/mindspore
add the support for empty list and tuple index contain shape '0'
This commit is contained in:
parent
e09aaafdaf
commit
bf6f0e1932
|
@ -129,7 +129,7 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
|
|||
return tuple_index_new
|
||||
|
||||
|
||||
def _expand_data_dims(data, tuple_index, op_name):
|
||||
def _expand_data_dims(data, tuple_index):
|
||||
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
expand_positions, tuple_index_new = (), ()
|
||||
|
@ -203,8 +203,14 @@ def tensor_index_by_list(data, list_index):
|
|||
indexes_types = hyper_map(F.typeof, list_index)
|
||||
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
|
||||
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
|
||||
if not sub_tuple_index:
|
||||
data_rank = len(data_shape)
|
||||
if data_rank == 1:
|
||||
return const_utils.make_tensor([], data.dtype, ())
|
||||
return const_utils.make_tensor([], data.dtype, data_shape[1:])
|
||||
tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
|
||||
return F.gather(data, tensor_index, 0)
|
||||
|
||||
tuple_index_new = ()
|
||||
for index in list_index:
|
||||
tuple_index_new += (index,)
|
||||
|
@ -219,7 +225,7 @@ def tensor_index_by_tuple(data, tuple_index):
|
|||
|
||||
op_name = const_utils.TENSOR_GETITEM
|
||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index)
|
||||
|
||||
data_shape = F.shape(data)
|
||||
data_rank = len(data_shape)
|
||||
|
@ -228,6 +234,7 @@ def tensor_index_by_tuple(data, tuple_index):
|
|||
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
||||
|
||||
if contain_type == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
|
||||
if contain_type == const_utils.ALL_BASIC:
|
||||
|
@ -245,7 +252,9 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
|
|||
tensor_index_shape = hyper_map(F.shape, tuple_index)
|
||||
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
|
||||
if 0 in broadcast_shape:
|
||||
res_shape = broadcast_shape + data_shape[tuple_index_len:]
|
||||
res_shape = broadcast_shape
|
||||
if tuple_index_len < len(data_shape):
|
||||
res_shape += data_shape[tuple_index_len:]
|
||||
res = const_utils.make_tensor([], data.dtype, res_shape)
|
||||
return res
|
||||
|
||||
|
@ -268,12 +277,68 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|||
|
||||
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
indices = _generate_indices_from_tuple(data, tuple_index, op_name)
|
||||
data_shape = F.shape(data)
|
||||
data_rank = len(data_shape)
|
||||
tuple_index_len = len(tuple_index)
|
||||
tensor_indexes, slice_indexes = [], []
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
||||
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
||||
tuple_index_new = ()
|
||||
|
||||
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
||||
if i in int_positions:
|
||||
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
|
||||
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
||||
tuple_index_new += (tensor_index,)
|
||||
tensor_indexes.append(tensor_index)
|
||||
tensor_positions.append(i)
|
||||
elif i in sequence_positions:
|
||||
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
|
||||
tensor_index = const_utils.make_tensor(sequence_index)
|
||||
tensor_index = F.cast(tensor_index, mstype.int64)
|
||||
tuple_index_new += (tensor_index,)
|
||||
tensor_indexes.append(tensor_index)
|
||||
tensor_positions.append(i)
|
||||
elif i in tensor_positions:
|
||||
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
|
||||
tensor_index = F.cast(index, mstype.int64)
|
||||
tuple_index_new += (tensor_index,)
|
||||
tensor_indexes.append(tensor_index)
|
||||
elif i in slice_positions:
|
||||
slice_indexes.append(index)
|
||||
tuple_index_new += (index,)
|
||||
|
||||
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
||||
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
||||
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
|
||||
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
|
||||
|
||||
if 0 in final_shape:
|
||||
if tuple_index_len < data_rank:
|
||||
final_shape = final_shape + data_shape[tuple_index_len:]
|
||||
return const_utils.make_tensor([], data.dtype, final_shape)
|
||||
|
||||
slice_number = 0
|
||||
final_index_tensors = []
|
||||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||
for i in range(tuple_index_len):
|
||||
if i in tensor_positions:
|
||||
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
||||
tuple_index_new[i])
|
||||
final_index_tensors.append(transform_tensor)
|
||||
if i in slice_positions:
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
||||
final_index_tensors.append(slice_tensor)
|
||||
slice_number += 1
|
||||
|
||||
indices = pack(final_index_tensors)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
||||
|
||||
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
indexes_types = hyper_map(F.dtype, tuple_index)
|
||||
|
@ -510,13 +575,13 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|||
return data
|
||||
op_name = const_utils.TENSOR_GETITEM
|
||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index)
|
||||
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if contain_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
|
||||
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
if int_cnt == const_utils.ALL_INT:
|
||||
|
@ -572,13 +637,13 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|||
return data
|
||||
op_name = const_utils.TENSOR_GETITEM
|
||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index)
|
||||
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if contain_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
|
||||
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
if int_cnt == const_utils.ALL_INT:
|
||||
|
@ -600,13 +665,13 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
|||
return data
|
||||
op_name = const_utils.TENSOR_GETITEM
|
||||
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
|
||||
data, tuple_index = _expand_data_dims(data, tuple_index)
|
||||
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if contain_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
|
||||
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
if int_cnt == const_utils.ALL_INT:
|
||||
|
|
Loading…
Reference in New Issue