forked from mindspore-Ecosystem/mindspore
add the func to expand dims by None for Tensor
This commit is contained in:
parent
ecc9f00c3c
commit
22a2f246c3
|
@ -19,6 +19,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 8;
|
||||
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
@ -205,7 +206,7 @@ void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
|
|||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output.";
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.size() > 4) {
|
||||
if (input_shape.size() > MAX_DIMS) {
|
||||
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower.";
|
||||
}
|
||||
if (input_shape.size() == 0) {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 7;
|
||||
constexpr int MAX_DIMS = 8;
|
||||
template <typename T>
|
||||
class StridedSliceGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -51,7 +51,8 @@ class StridedSliceGpuKernel : public GpuKernel {
|
|||
bool Init(const CNodePtr &kernel_node) override {
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape_.size() > MAX_DIMS) {
|
||||
MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size();
|
||||
MS_LOG(ERROR) << "StridedSlice support dims no more than " << MAX_DIMS << ", but the input shape is "
|
||||
<< input_shape_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -255,28 +255,38 @@ def tensor_index_by_tensor(data, tensor_index):
|
|||
"the index tensor data type only support mstype.int32.")
|
||||
|
||||
|
||||
def _tensor_index_by_tuple_slice(data, t):
|
||||
def _tensor_index_by_tuple_slice(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of slice"""
|
||||
shape = F.shape(data)
|
||||
if len(t) > len(shape):
|
||||
if len(tuple_index) > len(shape):
|
||||
const_utils.raise_index_error("When tensor is indexed by a tuple, "
|
||||
"the length of the tuple cannot be greater than the dimension of the tensor.")
|
||||
begin_strides, end_strides, step_strides, shrink_axis_mask = \
|
||||
const_utils.get_stride_info_from_tuple(shape, t)
|
||||
const_utils.get_stride_info_from_tuple(shape, tuple_index)
|
||||
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
||||
def tensor_expand_dims(data, tuple_index):
|
||||
"""Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """
|
||||
none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index)
|
||||
for position in none_positions:
|
||||
data = F.expand_dims(data, position)
|
||||
return data, tuple_index_without_none
|
||||
|
||||
|
||||
def tensor_index_by_tuple(data, tuple_index):
|
||||
"""Tensor getitem by tuple of various types"""
|
||||
"""Tensor getitem by tuple of various types with None"""
|
||||
# data, tuple_index_without_none = tensor_expand_dims(data, tuple_index)
|
||||
tuple_index_without_none = tuple_index
|
||||
if len(tuple_index) == 1:
|
||||
return data[tuple_index[0]]
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_index_by_tuple_slice(data, tuple_index)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
|
||||
return data[tuple_index_without_none[0]]
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
|
||||
tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if tensor_cnt == const_utils.NO_TENSOR:
|
||||
return _tensor_index_by_tuple_slice(data, tuple_index_without_none)
|
||||
if tensor_cnt == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none)
|
||||
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none)
|
||||
|
||||
|
||||
def _tensor_setitem(self, index, value):
|
||||
|
|
|
@ -66,6 +66,19 @@ def check_equal(param1, param2, msg="{},{}"):
|
|||
return param1
|
||||
|
||||
|
||||
@constexpr
|
||||
def split_tuple_index_for_none(tuple_index):
|
||||
"""return the none_positions and the tuple_index_without_none whose None index is replaced by slice."""
|
||||
none_positions, tuple_index_without_none = (), ()
|
||||
for idx, item in enumerate(tuple_index):
|
||||
if item is None:
|
||||
none_positions += (idx,)
|
||||
tuple_index_without_none += (slice(None, None, None),)
|
||||
else:
|
||||
tuple_index_without_none += (item,)
|
||||
return none_positions, tuple_index_without_none
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
||||
"""Checks the shape and size of the sensor and value."""
|
||||
|
@ -75,35 +88,6 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
|||
value_shape, data_shape))
|
||||
|
||||
|
||||
@constexpr
|
||||
def restrict_int_index(data_shape, tuple_indexes):
|
||||
"""
|
||||
Check the int index of tuple_indexes if value of index is out of the corresponding data shape
|
||||
and turn the negtive int index to positive int index.
|
||||
|
||||
Inputs:
|
||||
data_shape: the shape of data.
|
||||
tuple_indexes(tuple[mstype.int32]): the tuple of index which will be used in setitem or getitem.
|
||||
|
||||
Outputs:
|
||||
tuple_indexes_new(tuple[mstype.int32]): same purpose with tuple_indexes but only contain positive.
|
||||
"""
|
||||
if tuple_indexes is None:
|
||||
return tuple_indexes
|
||||
tuple_indexes_new = ()
|
||||
for i, index in enumerate(tuple_indexes):
|
||||
if isinstance(index, mstype.Int):
|
||||
if index < -data_shape[i] or index >= data_shape[i]:
|
||||
raise_index_error("The index is out of the data's special dimension range.")
|
||||
elif index < 0:
|
||||
tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
|
||||
else:
|
||||
tuple_indexes_new += (tuple_indexes[i],)
|
||||
else:
|
||||
tuple_indexes_new += (tuple_indexes[i],)
|
||||
return tuple_indexes_new
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensor_setitem_index(index, element_type=None):
|
||||
"""Checks tuple index type of tensor assignment."""
|
||||
|
|
|
@ -213,7 +213,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
|
|||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
tuple_index (tuple): Index in tuple.
|
||||
tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type is the same as the element type of data.
|
||||
|
|
Loading…
Reference in New Issue