add the support for empty list and tuple index contain shape '0'

This commit is contained in:
yepei6 2021-02-04 17:26:49 +08:00
parent e09aaafdaf
commit bf6f0e1932
1 changed files with 76 additions and 11 deletions

View File

@ -129,7 +129,7 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
return tuple_index_new
def _expand_data_dims(data, tuple_index, op_name):
def _expand_data_dims(data, tuple_index):
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
indexes_types = hyper_map(F.typeof, tuple_index)
expand_positions, tuple_index_new = (), ()
@ -203,8 +203,14 @@ def tensor_index_by_list(data, list_index):
indexes_types = hyper_map(F.typeof, list_index)
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
if not sub_tuple_index:
data_rank = len(data_shape)
if data_rank == 1:
return const_utils.make_tensor([], data.dtype, ())
return const_utils.make_tensor([], data.dtype, data_shape[1:])
tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
return F.gather(data, tensor_index, 0)
tuple_index_new = ()
for index in list_index:
tuple_index_new += (index,)
@ -219,7 +225,7 @@ def tensor_index_by_tuple(data, tuple_index):
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
data_shape = F.shape(data)
data_rank = len(data_shape)
@ -228,6 +234,7 @@ def tensor_index_by_tuple(data, tuple_index):
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
if contain_type == const_utils.ALL_BASIC:
@ -245,7 +252,9 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if 0 in broadcast_shape:
res_shape = broadcast_shape + data_shape[tuple_index_len:]
res_shape = broadcast_shape
if tuple_index_len < len(data_shape):
res_shape += data_shape[tuple_index_len:]
res = const_utils.make_tensor([], data.dtype, res_shape)
return res
@ -268,12 +277,68 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
"""Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple(data, tuple_index, op_name)
data_shape = F.shape(data)
data_rank = len(data_shape)
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in tensor_positions:
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
slice_indexes.append(index)
tuple_index_new += (index,)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
indexes_types = hyper_map(F.typeof, tuple_index_new)
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
if 0 in final_shape:
if tuple_index_len < data_rank:
final_shape = final_shape + data_shape[tuple_index_len:]
return const_utils.make_tensor([], data.dtype, final_shape)
slice_number = 0
final_index_tensors = []
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
indices = pack(final_index_tensors)
result = F.gather_nd(data, indices)
return result
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
indexes_types = hyper_map(F.dtype, tuple_index)
@ -510,13 +575,13 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
@ -572,13 +637,13 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
@ -600,13 +665,13 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT: