forked from mindspore-Ecosystem/mindspore
complete fancy index getitem
This commit is contained in:
parent
f62621a23d
commit
4c8f0914d0
|
@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|||
tuple_len = len(tuple_index)
|
||||
for i in range(tuple_len):
|
||||
if i in int_positions:
|
||||
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \
|
||||
data_shape[i], mstype.int32),)
|
||||
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
|
||||
data_shape[i], mstype.int32),)
|
||||
else:
|
||||
tuple_index_new += (tuple_index[i],)
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
||||
|
@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||
for i in range(tuple_index_size):
|
||||
if i in tensor_positions:
|
||||
transform_tensor = _transform_indexing_tensor(broadcast_shape,
|
||||
final_shape,
|
||||
index_tensor_new_shape,
|
||||
tuple_index_new[i])
|
||||
transform_tensor = _transform_indexing_tensor(
|
||||
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
|
||||
final_index_tensors.append(transform_tensor)
|
||||
if i in slice_positions:
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
|
||||
final_shape,
|
||||
indexes_shapes_info,
|
||||
op_name)
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
||||
final_index_tensors.append(slice_tensor)
|
||||
slice_number += 1
|
||||
if i == ellipsis_position:
|
||||
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
|
||||
ellipsis_occupied_dims,
|
||||
final_shape,
|
||||
indexes_shapes_info,
|
||||
op_name)
|
||||
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
|
||||
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
|
||||
for ele in ellipsis_tensors:
|
||||
final_index_tensors.append(ele)
|
||||
slice_number += ellipsis_occupied_dims
|
||||
|
@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index):
|
|||
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
||||
def tensor_expand_dims(data, tuple_index):
|
||||
"""Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """
|
||||
none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index)
|
||||
for position in none_positions:
|
||||
data = F.expand_dims(data, position)
|
||||
return data, tuple_index_without_none
|
||||
def tensor_index_by_list(data, list_index):
|
||||
"""Tensor getitem by list of int and bool"""
|
||||
data_shape = F.shape(data)
|
||||
const_utils.check_list_index_type(list_index)
|
||||
list_index = const_utils.transform_list(list_index, data_shape[0])
|
||||
tensor_index = const_utils.convert_list_to_tensor(list_index)
|
||||
return F.gather(data, tensor_index, 0)
|
||||
|
||||
|
||||
def tensor_index_by_tuple(data, tuple_index):
|
||||
|
|
|
@ -128,12 +128,14 @@ def is_same_type(inst, type_):
|
|||
"""
|
||||
return inst == type_
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_valid_dim(dim, name):
|
||||
if dim not in (1, 2):
|
||||
raise ValueError(
|
||||
f"For {name}, inputs dim must be 1d or 2d")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_valid_type(data_type, value_type, name):
|
||||
if not data_type in value_type:
|
||||
|
@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
|
|||
return tuple(new_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_list_index_type(list_index):
|
||||
"""check if the item's type of list_index is bool or int"""
|
||||
if not all([isinstance(index, (int, bool)) for index in list_index]):
|
||||
raise IndexError(
|
||||
f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")
|
||||
|
||||
|
||||
@constexpr
|
||||
def transform_list(list_index, shape):
|
||||
"""transfor list_index from int or bool to int"""
|
||||
bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index)))
|
||||
int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count
|
||||
if int_count == 0:
|
||||
if bool_count == shape:
|
||||
list_index = list(filter(lambda i: list_index[i], range(bool_count)))
|
||||
else:
|
||||
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
|
||||
else:
|
||||
list_index = [int(index) for index in list_index]
|
||||
for i, index in enumerate(list_index):
|
||||
if index < -shape or index >= shape:
|
||||
raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim "
|
||||
f"length, but get {index}.")
|
||||
if index < 0:
|
||||
index += shape
|
||||
list_index[i] = index
|
||||
return list_index
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_list_to_tensor(list_index):
|
||||
"""convert the list_index to tensor_index with mstype.int64 dtype"""
|
||||
return Tensor(list_index, mstype.int64)
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_int_to_slice(tuple_indexes):
|
||||
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
|
||||
|
|
|
@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
|||
Tensor, same as data.
|
||||
"""
|
||||
return data
|
||||
|
||||
|
||||
@getitem.register("Tensor", "List")
|
||||
def _tensor_getitem_by_list(data, list_index):
|
||||
"""
|
||||
Getting item of tensor by list.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor
|
||||
list_index (List): A list object.
|
||||
|
||||
Outputs:
|
||||
Tensor ,same as data.
|
||||
"""
|
||||
return compile_utils.tensor_index_by_list(data, list_index)
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_tensor_slice """
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
class NetWorkFancyIndexBoolean(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexBoolean, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
class NetWorkFancyIndexInterger(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexInterger, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
class NetWorkFancyIndexIntergerBooleanMixed(Cell):
|
||||
def __init__(self, index):
|
||||
super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__()
|
||||
self.index = index
|
||||
|
||||
def construct(self, tensor):
|
||||
return tensor[self.index]
|
||||
|
||||
|
||||
def test_tensor_fancy_index_integer_list():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [0, 2, 1]
|
||||
net = NetWorkFancyIndexBoolean(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
||||
|
||||
|
||||
def test_tensor_fancy_boolean_list():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [True, True, False]
|
||||
net = NetWorkFancyIndexInterger(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
||||
|
||||
|
||||
def test_tensor_fancy_integer_boolean_list():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
index = [1, 2, True, False]
|
||||
net = NetWorkFancyIndexIntergerBooleanMixed(index)
|
||||
input_np = np.arange(60).reshape(3, 4, 5)
|
||||
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||
output_me = net(input_me).asnumpy()
|
||||
output_np = input_np[index]
|
||||
assert np.allclose(output_np, output_me, 0, 0)
|
Loading…
Reference in New Issue