fancy index getitem
This commit is contained in:
parent
ae4e5b93eb
commit
5c9982729d
|
@ -57,6 +57,68 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
|||
return indices
|
||||
|
||||
|
||||
def _generate_indices_from_tuple(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
||||
data_shape = F.shape(data)
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
|
||||
tuple_index_new = ()
|
||||
tuple_len = len(tuple_index)
|
||||
for i in range(tuple_len):
|
||||
index = tuple_index[i]
|
||||
shape = data_shape[i]
|
||||
if i in int_positions:
|
||||
int_index = const_utils.check_and_transform_int_index(index, shape, op_name)
|
||||
tensor_index = F.scalar_to_tensor(int_index, mstype.int32)
|
||||
tuple_index_new += (tensor_index,)
|
||||
elif i in sequence_positions:
|
||||
sequence_index = const_utils.transform_sequence_index(index, shape, op_name)
|
||||
tensor_index = F.tuple_to_array(sequence_index)
|
||||
tuple_index_new += (tensor_index,)
|
||||
else:
|
||||
tuple_index_new += (index,)
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
||||
tensor_positions, slice_positions, ellipsis_position = \
|
||||
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
|
||||
tensor_indexes = []
|
||||
slice_indexes = []
|
||||
for i in tensor_positions:
|
||||
tensor_indexes.append(tuple_index_new[i])
|
||||
for j in slice_positions:
|
||||
slice_indexes.append(tuple_index_new[j])
|
||||
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
||||
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
||||
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
|
||||
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
|
||||
indexes_types,
|
||||
tensor_indexes_shapes,
|
||||
tensor_indexes_dtypes,
|
||||
slice_indexes,
|
||||
op_name)
|
||||
|
||||
slice_number = 0
|
||||
final_index_tensors = []
|
||||
tuple_index_size = len(tuple_index_new)
|
||||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||
for i in range(tuple_index_size):
|
||||
if i in tensor_positions:
|
||||
transform_tensor = _transform_indexing_tensor(
|
||||
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
|
||||
final_index_tensors.append(transform_tensor)
|
||||
if i in slice_positions:
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
||||
final_index_tensors.append(slice_tensor)
|
||||
slice_number += 1
|
||||
if i == ellipsis_position:
|
||||
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
|
||||
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
|
||||
for ele in ellipsis_tensors:
|
||||
final_index_tensors.append(ele)
|
||||
slice_number += ellipsis_occupied_dims
|
||||
indices = pack(final_index_tensors)
|
||||
return indices
|
||||
|
||||
|
||||
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
||||
data_shape = F.shape(data)
|
||||
|
@ -160,6 +222,8 @@ def _tensor_getitem(self, index):
|
|||
return tensor_index_by_tensor(self, index)
|
||||
if isinstance(index, tuple):
|
||||
return tensor_index_by_tuple(self, index)
|
||||
if isinstance(index, list):
|
||||
return tensor_index_by_list(self, index)
|
||||
# bool type should be judged before int
|
||||
if isinstance(index, bool):
|
||||
return _tensor_index_by_bool(self, index)
|
||||
|
@ -187,6 +251,13 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
|||
return result
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
|
@ -273,12 +344,12 @@ def tensor_index_by_tuple(data, tuple_index):
|
|||
if len(tuple_index) == 1:
|
||||
return data[tuple_index_without_none[0]]
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
|
||||
tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if tensor_cnt == const_utils.NO_TENSOR:
|
||||
return _tensor_index_by_tuple_slice(data, tuple_index_without_none)
|
||||
if tensor_cnt == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none)
|
||||
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none)
|
||||
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if contain_type == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
if contain_type == const_utils.ALL_BASIC:
|
||||
return _tensor_index_by_tuple_slice(data, tuple_index)
|
||||
return _tensor_getitem_by_tuple(data, tuple_index_without_none)
|
||||
|
||||
|
||||
def _tensor_setitem(self, index, value):
|
||||
|
|
|
@ -31,6 +31,8 @@ ALL_SCALAR = 3
|
|||
ALL_INT = 4
|
||||
NO_INT = 5
|
||||
CONTAIN_INT = 6
|
||||
ALL_BASIC = 7
|
||||
MIXED = 8
|
||||
|
||||
INT_ = 0
|
||||
BOOL_ = 1
|
||||
|
@ -307,6 +309,18 @@ def tuple_index_int_cnt(types, op_name):
|
|||
return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT
|
||||
|
||||
|
||||
@constexpr
|
||||
def tuple_index_type_cnt(types, op_name):
|
||||
"""count the tensor type of types which contains the tuple elements' type."""
|
||||
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
|
||||
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types)
|
||||
if tensor_cnt == len(types):
|
||||
return ALL_TENSOR
|
||||
if basic_cnt == len(types):
|
||||
return ALL_BASIC
|
||||
return MIXED
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_value_elements(data_dtype, types):
|
||||
"""Judges the type of all elements of the tuple."""
|
||||
|
@ -501,6 +515,34 @@ def convert_ellipsis_to_tensors(slice_number,
|
|||
return tensor_list
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_and_transform_int_index(index, shape, op_name):
|
||||
if index < -shape or index >= shape:
|
||||
raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
|
||||
f"the corresponding dim length, but get {index}.")
|
||||
if index < 0:
|
||||
index += shape
|
||||
return index
|
||||
|
||||
|
||||
@constexpr
|
||||
def transform_sequence_index(sequence_index, shape, op_name):
|
||||
"""transform list or tuple with integer and boolean to tuple with integer index"""
|
||||
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
|
||||
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
|
||||
if int_count == 0:
|
||||
if bool_count == shape:
|
||||
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
|
||||
else:
|
||||
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
|
||||
else:
|
||||
list_index = [int(index) for index in sequence_index]
|
||||
for i, index in enumerate(list_index):
|
||||
list_index[i] = check_and_transform_int_index(index, shape, op_name)
|
||||
sub_tuple_index = tuple(list_index)
|
||||
return sub_tuple_index
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
|
||||
"""Convert a slice to a tensor."""
|
||||
|
@ -702,6 +744,18 @@ def get_pos_of_int_index(indexes_types):
|
|||
return int_positions
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_pos_of_int_sequence(indexes_types):
|
||||
"""Get int and sequence index positions from the mixed tensors index."""
|
||||
int_positions, sequence_positions = [], []
|
||||
for i, index_type in enumerate(indexes_types):
|
||||
if isinstance(index_type, mstype.Int):
|
||||
int_positions.append(i)
|
||||
elif isinstance(index_type, (tuple, list)):
|
||||
sequence_positions.append(i)
|
||||
return int_positions, sequence_positions
|
||||
|
||||
|
||||
@constexpr
|
||||
def separate_mixed_tensors_index(indexes_types, op_name):
|
||||
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
|
||||
|
|
|
@ -206,21 +206,6 @@ def _tensor_getitem_by_tensor(data, tensor_index):
|
|||
return compile_utils.tensor_index_by_tensor(data, tensor_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Tuple")
|
||||
def _tensor_getitem_by_tuple(data, tuple_index):
|
||||
"""
|
||||
Getting item of tensor by tuple.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is the same as the element type of data.
|
||||
"""
|
||||
return compile_utils.tensor_index_by_tuple(data, tuple_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Ellipsis")
|
||||
def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
||||
"""
|
||||
|
@ -249,3 +234,18 @@ def _tensor_getitem_by_list(data, list_index):
|
|||
Tensor ,same as data.
|
||||
"""
|
||||
return compile_utils.tensor_index_by_list(data, list_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Tuple")
|
||||
def _tensor_getitem_by_tuple(data, tuple_index):
|
||||
"""
|
||||
Getting item of tensor by tuple.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is the same as the element type of data.
|
||||
"""
|
||||
return compile_utils.tensor_index_by_tuple(data, tuple_index)
|
||||
|
|
|
@ -21,61 +21,64 @@ from mindspore import dtype as mstype
|
|||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
class NetWorkFancyIndexBoolean(Cell):
|
||||
class NetWorkFancyIndex(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexBoolean, self).__init__()
|
||||
super(NetWorkFancyIndex, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
class NetWorkFancyIndexInterger(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexInterger, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
class NetWorkFancyIndexIntergerBooleanMixed(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
def test_tensor_fancy_index_integer_list():
|
||||
def test_tensor_fancy_index_integer_list_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [0, 2, 1]
|
||||
net = NetWorkFancyIndexBoolean(index)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
||||
net(input_me)
|
||||
|
||||
|
||||
def test_tensor_fancy_boolean_list():
|
||||
def test_tensor_fancy_boolean_list_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [True, True, False]
|
||||
net = NetWorkFancyIndexInterger(index)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
||||
net(input_me)
|
||||
|
||||
|
||||
def test_tensor_fancy_integer_boolean_list():
|
||||
def test_tensor_fancy_integer_boolean_list_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [1, 2, True, False]
|
||||
net = NetWorkFancyIndexIntergerBooleanMixed(index)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
||||
net(input_me)
|
||||
|
||||
|
||||
def test_tensor_fancy_integer_list_mixed_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
net(input_me)
|
||||
|
||||
|
||||
def test_tensor_fancy_integer_tuple_mixed_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
net(input_me)
|
||||
|
||||
|
||||
def test_tensor_fancy_integer_list_tuple_mixed_graph():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
|
||||
net = NetWorkFancyIndex(index)
|
||||
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
net(input_me)
|
Loading…
Reference in New Issue