forked from mindspore-Ecosystem/mindspore
!9737 complete fancy index getitem
From: @yepei6 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
a9f2de8307
|
@ -66,8 +66,8 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||||
tuple_len = len(tuple_index)
|
tuple_len = len(tuple_index)
|
||||||
for i in range(tuple_len):
|
for i in range(tuple_len):
|
||||||
if i in int_positions:
|
if i in int_positions:
|
||||||
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
|
||||||
data_shape[i], mstype.int32),)
|
data_shape[i], mstype.int32),)
|
||||||
else:
|
else:
|
||||||
tuple_index_new += (tuple_index[i],)
|
tuple_index_new += (tuple_index[i],)
|
||||||
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
||||||
|
@ -95,24 +95,16 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||||
for i in range(tuple_index_size):
|
for i in range(tuple_index_size):
|
||||||
if i in tensor_positions:
|
if i in tensor_positions:
|
||||||
transform_tensor = _transform_indexing_tensor(broadcast_shape,
|
transform_tensor = _transform_indexing_tensor(
|
||||||
final_shape,
|
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
|
||||||
index_tensor_new_shape,
|
|
||||||
tuple_index_new[i])
|
|
||||||
final_index_tensors.append(transform_tensor)
|
final_index_tensors.append(transform_tensor)
|
||||||
if i in slice_positions:
|
if i in slice_positions:
|
||||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
|
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
||||||
final_shape,
|
|
||||||
indexes_shapes_info,
|
|
||||||
op_name)
|
|
||||||
final_index_tensors.append(slice_tensor)
|
final_index_tensors.append(slice_tensor)
|
||||||
slice_number += 1
|
slice_number += 1
|
||||||
if i == ellipsis_position:
|
if i == ellipsis_position:
|
||||||
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
|
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
|
||||||
ellipsis_occupied_dims,
|
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
|
||||||
final_shape,
|
|
||||||
indexes_shapes_info,
|
|
||||||
op_name)
|
|
||||||
for ele in ellipsis_tensors:
|
for ele in ellipsis_tensors:
|
||||||
final_index_tensors.append(ele)
|
final_index_tensors.append(ele)
|
||||||
slice_number += ellipsis_occupied_dims
|
slice_number += ellipsis_occupied_dims
|
||||||
|
@ -266,12 +258,13 @@ def _tensor_index_by_tuple_slice(data, tuple_index):
|
||||||
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
||||||
|
|
||||||
|
|
||||||
def tensor_expand_dims(data, tuple_index):
|
def tensor_index_by_list(data, list_index):
|
||||||
"""Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """
|
"""Tensor getitem by list of int and bool"""
|
||||||
none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index)
|
data_shape = F.shape(data)
|
||||||
for position in none_positions:
|
const_utils.check_list_index_type(list_index)
|
||||||
data = F.expand_dims(data, position)
|
list_index = const_utils.transform_list(list_index, data_shape[0])
|
||||||
return data, tuple_index_without_none
|
tensor_index = const_utils.convert_list_to_tensor(list_index)
|
||||||
|
return F.gather(data, tensor_index, 0)
|
||||||
|
|
||||||
|
|
||||||
def tensor_index_by_tuple(data, tuple_index):
|
def tensor_index_by_tuple(data, tuple_index):
|
||||||
|
|
|
@ -128,12 +128,14 @@ def is_same_type(inst, type_):
|
||||||
"""
|
"""
|
||||||
return inst == type_
|
return inst == type_
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_valid_dim(dim, name):
|
def check_valid_dim(dim, name):
|
||||||
if dim not in (1, 2):
|
if dim not in (1, 2):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"For {name}, inputs dim must be 1d or 2d")
|
f"For {name}, inputs dim must be 1d or 2d")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_valid_type(data_type, value_type, name):
|
def check_valid_type(data_type, value_type, name):
|
||||||
if not data_type in value_type:
|
if not data_type in value_type:
|
||||||
|
@ -422,6 +424,42 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
|
||||||
return tuple(new_shape)
|
return tuple(new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def check_list_index_type(list_index):
|
||||||
|
"""check if the item's type of list_index is bool or int"""
|
||||||
|
if not all([isinstance(index, (int, bool)) for index in list_index]):
|
||||||
|
raise IndexError(
|
||||||
|
f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def transform_list(list_index, shape):
|
||||||
|
"""transfor list_index from int or bool to int"""
|
||||||
|
bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index)))
|
||||||
|
int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count
|
||||||
|
if int_count == 0:
|
||||||
|
if bool_count == shape:
|
||||||
|
list_index = list(filter(lambda i: list_index[i], range(bool_count)))
|
||||||
|
else:
|
||||||
|
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
|
||||||
|
else:
|
||||||
|
list_index = [int(index) for index in list_index]
|
||||||
|
for i, index in enumerate(list_index):
|
||||||
|
if index < -shape or index >= shape:
|
||||||
|
raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim "
|
||||||
|
f"length, but get {index}.")
|
||||||
|
if index < 0:
|
||||||
|
index += shape
|
||||||
|
list_index[i] = index
|
||||||
|
return list_index
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def convert_list_to_tensor(list_index):
|
||||||
|
"""convert the list_index to tensor_index with mstype.int64 dtype"""
|
||||||
|
return Tensor(list_index, mstype.int64)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def convert_int_to_slice(tuple_indexes):
|
def convert_int_to_slice(tuple_indexes):
|
||||||
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
|
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
|
||||||
|
|
|
@ -234,3 +234,18 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
||||||
Tensor, same as data.
|
Tensor, same as data.
|
||||||
"""
|
"""
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@getitem.register("Tensor", "List")
|
||||||
|
def _tensor_getitem_by_list(data, list_index):
|
||||||
|
"""
|
||||||
|
Getting item of tensor by list.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): A tensor
|
||||||
|
list_index (List): A list object.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor ,same as data.
|
||||||
|
"""
|
||||||
|
return compile_utils.tensor_index_by_list(data, list_index)
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test_tensor_slice """
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import dtype as mstype
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
|
||||||
|
class NetWorkFancyIndexBoolean(Cell):
|
||||||
|
def __init__(self, index):
|
||||||
|
super(NetWorkFancyIndexBoolean, self).__init__()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
def construct(self, tensor):
|
||||||
|
return tensor[self.index]
|
||||||
|
|
||||||
|
|
||||||
|
class NetWorkFancyIndexInterger(Cell):
|
||||||
|
def __init__(self, index):
|
||||||
|
super(NetWorkFancyIndexInterger, self).__init__()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
def construct(self, tensor):
|
||||||
|
return tensor[self.index]
|
||||||
|
|
||||||
|
|
||||||
|
class NetWorkFancyIndexIntergerBooleanMixed(Cell):
|
||||||
|
def __init__(self, index):
|
||||||
|
super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__()
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
def construct(self, tensor):
|
||||||
|
return tensor[self.index]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_fancy_index_integer_list():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
index = [0, 2, 1]
|
||||||
|
net = NetWorkFancyIndexBoolean(index)
|
||||||
|
input_np = np.arange(60).reshape(3, 4, 5)
|
||||||
|
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||||
|
output_me = net(input_me).asnumpy()
|
||||||
|
output_np = input_np[index]
|
||||||
|
assert np.allclose(output_np, output_me, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_fancy_boolean_list():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
index = [True, True, False]
|
||||||
|
net = NetWorkFancyIndexInterger(index)
|
||||||
|
input_np = np.arange(60).reshape(3, 4, 5)
|
||||||
|
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||||
|
output_me = net(input_me).asnumpy()
|
||||||
|
output_np = input_np[index]
|
||||||
|
assert np.allclose(output_np, output_me, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_fancy_integer_boolean_list():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
index = [1, 2, True, False]
|
||||||
|
net = NetWorkFancyIndexIntergerBooleanMixed(index)
|
||||||
|
input_np = np.arange(60).reshape(3, 4, 5)
|
||||||
|
input_me = Tensor(input_np, dtype=mstype.float32)
|
||||||
|
output_me = net(input_me).asnumpy()
|
||||||
|
output_np = input_np[index]
|
||||||
|
assert np.allclose(output_np, output_me, 0, 0)
|
Loading…
Reference in New Issue