Add a function to expand Tensor dims via None in the index tuple

This commit is contained in:
Payne 2020-11-09 15:07:35 +08:00
parent ecc9f00c3c
commit 22a2f246c3
5 changed files with 41 additions and 45 deletions

View File

@ -19,6 +19,7 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 8;
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@ -205,7 +206,7 @@ void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output.";
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) {
if (input_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel only supports 8d or lower.";
}
if (input_shape.size() == 0) {

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
constexpr int MAX_DIMS = 8;
template <typename T>
class StridedSliceGpuKernel : public GpuKernel {
public:
@ -51,7 +51,8 @@ class StridedSliceGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape_.size() > MAX_DIMS) {
MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size();
MS_LOG(ERROR) << "StridedSlice support dims no more than " << MAX_DIMS << ", but the input shape is "
<< input_shape_.size();
return false;
}

View File

@ -255,28 +255,38 @@ def tensor_index_by_tensor(data, tensor_index):
"the index tensor data type only support mstype.int32.")
def _tensor_index_by_tuple_slice(data, t):
def _tensor_index_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
shape = F.shape(data)
if len(t) > len(shape):
if len(tuple_index) > len(shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, "
"the length of the tuple cannot be greater than the dimension of the tensor.")
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(shape, t)
const_utils.get_stride_info_from_tuple(shape, tuple_index)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
def tensor_expand_dims(data, tuple_index):
    """Insert a size-1 dimension for every None found in tuple_index.

    Returns the expanded tensor together with the index tuple in which each
    None has been replaced by a full slice, ready for downstream indexing.
    """
    positions, cleaned_index = const_utils.split_tuple_index_for_none(tuple_index)
    expanded = data
    for axis in positions:
        expanded = F.expand_dims(expanded, axis)
    return expanded, cleaned_index
def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types"""
"""Tensor getitem by tuple of various types with None"""
# data, tuple_index_without_none = tensor_expand_dims(data, tuple_index)
tuple_index_without_none = tuple_index
if len(tuple_index) == 1:
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_index_by_tuple_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
return data[tuple_index_without_none[0]]
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
if tensor_cnt == const_utils.NO_TENSOR:
return _tensor_index_by_tuple_slice(data, tuple_index_without_none)
if tensor_cnt == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none)
def _tensor_setitem(self, index, value):

View File

@ -66,6 +66,19 @@ def check_equal(param1, param2, msg="{},{}"):
return param1
@constexpr
def split_tuple_index_for_none(tuple_index):
    """Split the positions of None entries out of an index tuple.

    Returns a pair ``(none_positions, tuple_index_without_none)``: the tuple of
    positions where ``tuple_index`` holds None, and a copy of ``tuple_index``
    in which each None has been replaced by ``slice(None, None, None)``.
    """
    none_positions = []
    cleaned = []
    for pos, entry in enumerate(tuple_index):
        if entry is None:
            none_positions.append(pos)
            cleaned.append(slice(None, None, None))
        else:
            cleaned.append(entry)
    return tuple(none_positions), tuple(cleaned)
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the tensor and value."""
@ -75,35 +88,6 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
value_shape, data_shape))
@constexpr
def restrict_int_index(data_shape, tuple_indexes):
    """Validate int entries of tuple_indexes and normalize negative ones.

    Checks every int index against the matching entry of data_shape, raising
    an index error when it falls outside the valid range, and converts valid
    negative int indexes to their positive equivalents.

    Inputs:
        data_shape: the shape of the data being indexed.
        tuple_indexes (tuple[mstype.int32]): index tuple used in setitem or getitem.

    Outputs:
        tuple: same entries as tuple_indexes, with int entries made non-negative.
    """
    if tuple_indexes is None:
        return tuple_indexes
    normalized = ()
    for dim, index in enumerate(tuple_indexes):
        if not isinstance(index, mstype.Int):
            # Non-int entries (slices, tensors, ...) pass through untouched.
            normalized += (index,)
            continue
        if index < -data_shape[dim] or index >= data_shape[dim]:
            raise_index_error("The index is out of the data's special dimension range.")
        elif index < 0:
            normalized += (index + data_shape[dim],)
        else:
            normalized += (index,)
    return normalized
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""

View File

@ -213,7 +213,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Inputs:
data (Tensor): A tensor.
tuple_index (tuple): Index in tuple.
tuple_index (tuple): Index in tuple, which may include ellipsis, slice, int, Tensor, None, list, tuple.
Outputs:
Tensor, element type is the same as the element type of data.