diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index aea85bb8aa9..479c6bcee2d 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): tuple_len = len(tuple_index) for i in range(tuple_len): if i in int_positions: - tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \ - data_shape[i], mstype.int32),) + tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + + data_shape[i], mstype.int32),) else: tuple_index_new += (tuple_index[i],) indexes_types = hyper_map(F.typeof, tuple_index_new) @@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) for i in range(tuple_index_size): if i in tensor_positions: - transform_tensor = _transform_indexing_tensor(broadcast_shape, - final_shape, - index_tensor_new_shape, - tuple_index_new[i]) + transform_tensor = _transform_indexing_tensor( + broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) final_index_tensors.append(transform_tensor) if i in slice_positions: - slice_tensor = const_utils.convert_slice_to_tensor(slice_number, - final_shape, - indexes_shapes_info, - op_name) + slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) final_index_tensors.append(slice_tensor) slice_number += 1 if i == ellipsis_position: - ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number, - ellipsis_occupied_dims, - final_shape, - indexes_shapes_info, - op_name) + ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( + slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) for ele in ellipsis_tensors: final_index_tensors.append(ele) slice_number += ellipsis_occupied_dims @@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index): return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) -def tensor_expand_dims(data, tuple_index): - """Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """ - none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index) - for position in none_positions: - data = F.expand_dims(data, position) - return data, tuple_index_without_none +def tensor_index_by_list(data, list_index): + """Tensor getitem by list of int and bool""" + data_shape = F.shape(data) + const_utils.check_list_index_type(list_index) + list_index = const_utils.transform_list(list_index, data_shape[0]) + tensor_index = const_utils.convert_list_to_tensor(list_index) + return F.gather(data, tensor_index, 0) def tensor_index_by_tuple(data, tuple_index): diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index b6f29aeeb3a..ddfdfc6927d 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -128,12 +128,14 @@ def is_same_type(inst, type_): """ return inst == type_ + @constexpr def check_valid_dim(dim, name): if dim not in (1, 2): raise ValueError( f"For {name}, inputs dim must be 1d or 2d") + @constexpr def check_valid_type(data_type, value_type, name): if not data_type in value_type: @@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info): return tuple(new_shape) +@constexpr +def check_list_index_type(list_index): + """check if the item's type of list_index is bool or int""" + if not all([isinstance(index, (int, bool)) for index in list_index]): + raise IndexError( + f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array") + + +@constexpr +def transform_list(list_index, shape): + """transfor list_index from int or bool to int""" + bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index))) + int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count + if int_count == 0: + if bool_count == shape: + list_index = list(filter(lambda i: list_index[i], range(bool_count))) + else: + raise IndexError("The boolean array should have the same length with the corresponding dimensiton") + else: + list_index = [int(index) for index in list_index] + for i, index in enumerate(list_index): + if index < -shape or index >= shape: + raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim " + f"length, but get {index}.") + if index < 0: + index += shape + list_index[i] = index + return list_index + + +@constexpr +def convert_list_to_tensor(list_index): + """convert the list_index to tensor_index with mstype.int64 dtype""" + return Tensor(list_index, mstype.int64) + + @constexpr def convert_int_to_slice(tuple_indexes): tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 3194981845f..20d43868496 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): Tensor, same as data. """ return data + + +@getitem.register("Tensor", "List") +def _tensor_getitem_by_list(data, list_index): + """ + Getting item of tensor by list. + + Inputs: + data (Tensor): A tensor + list_index (List): A list object. + + Outputs: + Tensor ,same as data. + """ + return compile_utils.tensor_index_by_list(data, list_index) diff --git a/tests/ut/python/ops/ test_tensor_fancy_index.py b/tests/ut/python/ops/ test_tensor_fancy_index.py new file mode 100644 index 00000000000..1d883a97702 --- /dev/null +++ b/tests/ut/python/ops/ test_tensor_fancy_index.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_tensor_slice """ +import numpy as np + +from mindspore import Tensor +from mindspore import context +from mindspore import dtype as mstype +from mindspore.nn import Cell + + +class NetWorkFancyIndexBoolean(Cell): + def __init__(self, index): + super(NetWorkFancyIndexBoolean, self).__init__() + self.index = index + + def construct(self, tensor): + return tensor[self.index] + + +class NetWorkFancyIndexInterger(Cell): + def __init__(self, index): + super(NetWorkFancyIndexInterger, self).__init__() + self.index = index + + def construct(self, tensor): + return tensor[self.index] + + +class NetWorkFancyIndexIntergerBooleanMixed(Cell): + def __init__(self, index): + super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__() + self.index = index + + def construct(self, tensor): + return tensor[self.index] + + +def test_tensor_fancy_index_integer_list(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = [0, 2, 1] + net = NetWorkFancyIndexBoolean(index) + input_np = np.arange(60).reshape(3, 4, 5) + input_me = Tensor(input_np, dtype=mstype.float32) + output_me = net(input_me).asnumpy() + output_np = input_np[index] + assert np.allclose(output_np, output_me, 0, 0) + + +def test_tensor_fancy_boolean_list(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = [True, True, False] + net = NetWorkFancyIndexInterger(index) + input_np = np.arange(60).reshape(3, 4, 5) + input_me = Tensor(input_np, dtype=mstype.float32) + output_me = net(input_me).asnumpy() + output_np = input_np[index] + assert np.allclose(output_np, output_me, 0, 0) + + +def test_tensor_fancy_integer_boolean_list(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = [1, 2, True, False] + net = NetWorkFancyIndexIntergerBooleanMixed(index) + input_np = np.arange(60).reshape(3, 4, 5) + input_me = Tensor(input_np, dtype=mstype.float32) + output_me = net(input_me).asnumpy() + output_np = input_np[index] + assert np.allclose(output_np, output_me, 0, 0)