complete fancy index getitem

This commit is contained in:
Payne 2020-12-14 10:21:43 +08:00
parent f62621a23d
commit 4c8f0914d0
4 changed files with 148 additions and 21 deletions

View File

@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
tuple_len = len(tuple_index)
for i in range(tuple_len):
if i in int_positions:
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \
data_shape[i], mstype.int32),)
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
data_shape[i], mstype.int32),)
else:
tuple_index_new += (tuple_index[i],)
indexes_types = hyper_map(F.typeof, tuple_index_new)
@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
final_shape,
indexes_shapes_info,
op_name)
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name)
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims
@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index):
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
def tensor_expand_dims(data, tuple_index):
"""Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """
none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index)
for position in none_positions:
data = F.expand_dims(data, position)
return data, tuple_index_without_none
def tensor_index_by_list(data, list_index):
"""Tensor getitem by list of int and bool"""
data_shape = F.shape(data)
const_utils.check_list_index_type(list_index)
list_index = const_utils.transform_list(list_index, data_shape[0])
tensor_index = const_utils.convert_list_to_tensor(list_index)
return F.gather(data, tensor_index, 0)
def tensor_index_by_tuple(data, tuple_index):

View File

@ -128,12 +128,14 @@ def is_same_type(inst, type_):
"""
return inst == type_
@constexpr
def check_valid_dim(dim, name):
if dim not in (1, 2):
raise ValueError(
f"For {name}, inputs dim must be 1d or 2d")
@constexpr
def check_valid_type(data_type, value_type, name):
if not data_type in value_type:
@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
return tuple(new_shape)
@constexpr
def check_list_index_type(list_index):
"""check if the item's type of list_index is bool or int"""
if not all([isinstance(index, (int, bool)) for index in list_index]):
raise IndexError(
f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")
@constexpr
def transform_list(list_index, shape):
"""transfor list_index from int or bool to int"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index)))
int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count
if int_count == 0:
if bool_count == shape:
list_index = list(filter(lambda i: list_index[i], range(bool_count)))
else:
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
else:
list_index = [int(index) for index in list_index]
for i, index in enumerate(list_index):
if index < -shape or index >= shape:
raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim "
f"length, but get {index}.")
if index < 0:
index += shape
list_index[i] = index
return list_index
@constexpr
def convert_list_to_tensor(list_index):
"""convert the list_index to tensor_index with mstype.int64 dtype"""
return Tensor(list_index, mstype.int64)
@constexpr
def convert_int_to_slice(tuple_indexes):
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)

View File

@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Tensor, same as data.
"""
return data
@getitem.register("Tensor", "List")
def _tensor_getitem_by_list(data, list_index):
"""
Getting item of tensor by list.
Inputs:
data (Tensor): A tensor
list_index (List): A list object.
Outputs:
Tensor ,same as data.
"""
return compile_utils.tensor_index_by_list(data, list_index)

View File

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_tensor_slice """
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
class NetWorkFancyIndexBoolean(Cell):
def __init__(self, index):
super(NetWorkFancyIndexBoolean, self).__init__()
self.index = index
def construct(self, tensor):
return tensor[self.index]
class NetWorkFancyIndexInterger(Cell):
def __init__(self, index):
super(NetWorkFancyIndexInterger, self).__init__()
self.index = index
def construct(self, tensor):
return tensor[self.index]
class NetWorkFancyIndexIntergerBooleanMixed(Cell):
def __init__(self, index):
super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__()
self.index = index
def construct(self, tensor):
return tensor[self.index]
def test_tensor_fancy_index_integer_list():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [0, 2, 1]
net = NetWorkFancyIndexBoolean(index)
input_np = np.arange(60).reshape(3, 4, 5)
input_me = Tensor(input_np, dtype=mstype.float32)
output_me = net(input_me).asnumpy()
output_np = input_np[index]
assert np.allclose(output_np, output_me, 0, 0)
def test_tensor_fancy_boolean_list():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [True, True, False]
net = NetWorkFancyIndexInterger(index)
input_np = np.arange(60).reshape(3, 4, 5)
input_me = Tensor(input_np, dtype=mstype.float32)
output_me = net(input_me).asnumpy()
output_np = input_np[index]
assert np.allclose(output_np, output_me, 0, 0)
def test_tensor_fancy_integer_boolean_list():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [1, 2, True, False]
net = NetWorkFancyIndexIntergerBooleanMixed(index)
input_np = np.arange(60).reshape(3, 4, 5)
input_me = Tensor(input_np, dtype=mstype.float32)
output_me = net(input_me).asnumpy()
output_np = input_np[index]
assert np.allclose(output_np, output_me, 0, 0)