From 22a2f246c340072fde83ea751ec5f3d03dcc9184 Mon Sep 17 00:00:00 2001 From: Payne Date: Mon, 9 Nov 2020 15:07:35 +0800 Subject: [PATCH] add the func to expand dims by None for Tensor --- .../kernel_compiler/cpu/slice_cpu_kernel.cc | 3 +- .../gpu/arrays/strided_slice_gpu_kernel.h | 5 ++- .../composite/multitype_ops/_compile_utils.py | 34 +++++++++------ .../multitype_ops/_constexpr_utils.py | 42 ++++++------------- .../composite/multitype_ops/getitem_impl.py | 2 +- 5 files changed, 41 insertions(+), 45 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc index cbc0d42a25b..b4ba221e3d2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc @@ -19,6 +19,7 @@ namespace mindspore { namespace kernel { +constexpr int MAX_DIMS = 8; void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -205,7 +206,7 @@ void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { + if (input_shape.size() > MAX_DIMS) { MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; } if (input_shape.size() == 0) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h index 60d2bd926fb..9eba2321971 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h @@ -26,7 +26,7 @@ namespace mindspore { namespace kernel { -constexpr int MAX_DIMS = 7; +constexpr int MAX_DIMS = 8; template class StridedSliceGpuKernel : public GpuKernel { public: @@ -51,7 +51,8 @@ class StridedSliceGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); if (input_shape_.size() > MAX_DIMS) { - MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size(); + MS_LOG(ERROR) << "StridedSlice support dims no more than " << MAX_DIMS << ", but the input shape is " + << input_shape_.size(); return false; } diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index f592eba7b30..e184eade8e9 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -255,28 +255,38 @@ def tensor_index_by_tensor(data, tensor_index): "the index tensor data type only support mstype.int32.") -def _tensor_index_by_tuple_slice(data, t): +def _tensor_index_by_tuple_slice(data, tuple_index): """Tensor getitem by a tuple of slice""" shape = F.shape(data) - if len(t) > len(shape): + if len(tuple_index) > len(shape): const_utils.raise_index_error("When tensor is indexed by a tuple, " "the length of the tuple cannot be greater than the dimension of the tensor.") begin_strides, end_strides, step_strides, shrink_axis_mask = \ - const_utils.get_stride_info_from_tuple(shape, t) + const_utils.get_stride_info_from_tuple(shape, tuple_index) return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) +def tensor_expand_dims(data, tuple_index): + """Expand tensor dims by tuple contains None and replace the None by slice in tuple_index """ + none_positions, tuple_index_without_none = const_utils.split_tuple_index_for_none(tuple_index) + for position in none_positions: + data = F.expand_dims(data, position) + return data, tuple_index_without_none + + def tensor_index_by_tuple(data, tuple_index): - """Tensor getitem by tuple of various types""" + """Tensor getitem by tuple of various types with None""" + # data, tuple_index_without_none = tensor_expand_dims(data, tuple_index) + tuple_index_without_none = tuple_index if len(tuple_index) == 1: - return data[tuple_index[0]] - indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) - if index_elements_type == const_utils.NO_TENSOR: - return _tensor_index_by_tuple_slice(data, tuple_index) - if index_elements_type == const_utils.ALL_TENSOR: - return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) - return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) + return data[tuple_index_without_none[0]] + indexes_types = hyper_map(F.typeof, tuple_index_without_none) + tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) + if tensor_cnt == const_utils.NO_TENSOR: + return _tensor_index_by_tuple_slice(data, tuple_index_without_none) + if tensor_cnt == const_utils.ALL_TENSOR: + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none) + return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none) def _tensor_setitem(self, index, value): diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 1f9fa4c24be..b6f29aeeb3a 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -66,6 +66,19 @@ def check_equal(param1, param2, msg="{},{}"): return param1 +@constexpr +def split_tuple_index_for_none(tuple_index): + """return the none_positions and the tuple_index_without_none whose None index is replaced by slice.""" + none_positions, tuple_index_without_none = (), () + for idx, item in enumerate(tuple_index): + if item is None: + none_positions += (idx,) + tuple_index_without_none += (slice(None, None, None),) + else: + tuple_index_without_none += (item,) + return none_positions, tuple_index_without_none + + @constexpr def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): """Checks the shape and size of the sensor and value.""" @@ -75,35 +88,6 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): value_shape, data_shape)) -@constexpr -def restrict_int_index(data_shape, tuple_indexes): - """ - Check the int index of tuple_indexes if value of index is out of the corresponding data shape - and turn the negtive int index to positive int index. - - Inputs: - data_shape: the shape of data. - tuple_indexes(tuple[mstype.int32]): the tuple of index which will be used in setitem or getitem. - - Outputs: - tuple_indexes_new(tuple[mstype.int32]): same purpose with tuple_indexes but only contain positive. - """ - if tuple_indexes is None: - return tuple_indexes - tuple_indexes_new = () - for i, index in enumerate(tuple_indexes): - if isinstance(index, mstype.Int): - if index < -data_shape[i] or index >= data_shape[i]: - raise_index_error("The index is out of the data's special dimension range.") - elif index < 0: - tuple_indexes_new += (tuple_indexes[i]+data_shape[i],) - else: - tuple_indexes_new += (tuple_indexes[i],) - else: - tuple_indexes_new += (tuple_indexes[i],) - return tuple_indexes_new - - @constexpr def check_tensor_setitem_index(index, element_type=None): """Checks tuple index type of tensor assignment.""" diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index fedbc1fa3fe..3194981845f 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -213,7 +213,7 @@ def _tensor_getitem_by_tuple(data, tuple_index): Inputs: data (Tensor): A tensor. - tuple_index (tuple): Index in tuple. + tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple. Outputs: Tensor, element type is the same as the element type of data.