From fa8b473a24eaa42b9b2a4aca8b365990dbcdf4bd Mon Sep 17 00:00:00 2001
From: qinzheng
Date: Thu, 9 Jun 2022 16:27:45 +0800
Subject: [PATCH] support infershape tensor functional api of split

---
 mindspore/ccsrc/pipeline/jit/resource.cc      |  1 +
 mindspore/core/ops/split.cc                   | 63 +++++++++++++++++++
 .../_extends/parse/standard_method.py         |  8 +++
 mindspore/python/mindspore/common/tensor.py   | 38 +++++++++++
 .../python/mindspore/ops/function/__init__.py |  1 +
 .../mindspore/ops/function/array_func.py      | 55 +++++++++++++++-
 mindspore/python/mindspore/ops/functional.py  |  1 +
 .../mindspore/ops/operations/array_ops.py     | 19 +-----
 8 files changed, 167 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc
index 8c515a72da6..c91553bc7a2 100644
--- a/mindspore/ccsrc/pipeline/jit/resource.cc
+++ b/mindspore/ccsrc/pipeline/jit/resource.cc
@@ -263,6 +263,7 @@ BuiltInTypeMap &GetMethodMap() {
                      {"to_coo", std::string("to_coo")},            // dense_to_sparse_coo()
                      {"to_csr", std::string("to_csr")},            // dense_to_sparse_csr()
                      {"col2im", std::string("col2im")},            // P.Col2Im
+                     {"split", std::string("split")},              // P.Split()
                    }},
   {kObjectTypeRowTensorType,
    {
diff --git a/mindspore/core/ops/split.cc b/mindspore/core/ops/split.cc
index fd2ba0f44b9..1589abe861d 100644
--- a/mindspore/core/ops/split.cc
+++ b/mindspore/core/ops/split.cc
@@ -49,6 +49,69 @@ int64_t Split::get_output_num() const {
   return GetValue<int64_t>(value_ptr);
 }
 
+namespace {
+abstract::TupleShapePtr SplitInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  (void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(input_args.size()), kEqual, 1L, prim_name);
+  auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
+  MS_EXCEPTION_IF_NULL(shape_ptr);
+  auto input_shape = shape_ptr->shape();
+  auto input_min_shape = shape_ptr->min_shape();
+  auto input_max_shape = shape_ptr->max_shape();
+
+  auto input_rank = SizeToLong(input_shape.size());
+  auto output_num = GetValue<int64_t>(primitive->GetAttr(kOutputNum));
+  auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
+  (void)CheckAndConvertUtils::CheckInteger("input_rank", input_rank, kGreaterEqual, 1, prim_name);
+
+  ShapeVector out_shape = input_shape;
+  ShapeVector out_min_shape = input_min_shape;
+  ShapeVector out_max_shape = input_max_shape;
+  if (!shape_ptr->IsDimUnknown()) {
+    (void)CheckAndConvertUtils::CheckInteger(kAxis, axis, kLessThan, input_rank, prim_name);
+    axis = axis < 0 ? axis + input_rank : axis;
+    int64_t split_axis_length = -1;
+    if (input_shape[axis] != -1) {
+      (void)CheckAndConvertUtils::CheckInteger("input x", input_shape[axis] % output_num, kEqual, 0L, prim_name);
+      split_axis_length = input_shape[axis] / output_num;
+    }
+    out_shape[axis] = split_axis_length;
+  }
+  std::vector<abstract::BaseShapePtr> shape_tuple;
+  for (int64_t i = 0; i < output_num; i++) {
+    abstract::ShapePtr output_shape = std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
+    shape_tuple.push_back(output_shape);
+  }
+  return std::make_shared<abstract::TupleShape>(shape_tuple);
+}
+
+TuplePtr SplitInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto output_num = GetValue<int64_t>(prim->GetAttr(kOutputNum));
+  auto infer_type = input_args[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(infer_type);
+  const std::set<TypePtr> valid_types = {kInt8,   kInt16,  kInt32,  kInt64,   kUInt8,
+                                         kUInt16, kUInt32, kUInt64, kFloat16, kFloat32};
+  auto type = CheckAndConvertUtils::CheckTensorTypeValid("input_x", infer_type, valid_types, prim->name());
+  std::vector<TypePtr> type_tuple;
+  for (int64_t i = 0; i < output_num; i++) {
+    type_tuple.push_back(type);
+  }
+  return std::make_shared<Tuple>(type_tuple);
+}
+}  // namespace
+
+AbstractBasePtr SplitInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto infertype = SplitInferType(primitive, input_args);
+  auto infershape = SplitInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infershape, infertype);
+}
+
 REGISTER_PRIMITIVE_C(kNameSplit, Split);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py
index 0c2cfb69681..c3cdaed9a38 100644
--- a/mindspore/python/mindspore/_extends/parse/standard_method.py
+++ b/mindspore/python/mindspore/_extends/parse/standard_method.py
@@ -2359,3 +2359,11 @@ def pdist(x, p=2.0):
     Refer to :func:`mindspore.ops.pdist` for more detail.
     """
     return F.pdist(x, p=p)
+
+
+def split(input_x, axis=0, output_num=1):
+    """
+    Splits the input tensor into `output_num` tensors along the given `axis`.
+    Refer to :func:`mindspore.ops.split` for more detail.
+    """
+    return F.split(input_x, axis, output_num)
diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py
index e0b1ce9745f..49f62418ac6 100644
--- a/mindspore/python/mindspore/common/tensor.py
+++ b/mindspore/python/mindspore/common/tensor.py
@@ -4253,6 +4253,44 @@ class Tensor(Tensor_):
         validator.check_int(len(output_size), 2, Rel.EQ, "length of output_size", Tensor)
         return tensor_operator_registry.get("adaptive_avgpool2d")(self, output_size)
 
+    def split(self, axis=0, output_num=1):
+        """
+        Splits the tensor into `output_num` tensors along the given `axis`.
+
+        The tensor will be split into equally sized sub-tensors.
+        This requires that `self.shape[axis]` is divisible by `output_num`.
+
+        Args:
+            axis (int): Index of the split position. Default: 0.
+            output_num (int): The number of output tensors. Must be a positive int. Default: 1.
+
+        Returns:
+            tuple[Tensor], the shape of each output tensor is the same, which is
+            :math:`(y_1, y_2, ..., y_S)`. The data type is the same as this tensor.
+
+        Raises:
+            TypeError: If `axis` or `output_num` is not an int.
+            ValueError: If `axis` is out of the range [-len(`self.shape`), len(`self.shape`)),
+                or if the `output_num` is less than or equal to 0.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
+            >>> print(x)
+            [[1 1 1 1]
+             [2 2 2 2]]
+            >>> output = x.split(1, 2)
+            >>> print(output)
+            (Tensor(shape=[2, 2], dtype=Int32, value=
+            [[1, 1],
+             [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
+            [[1, 1],
+             [2, 2]]))
+        """
+        return tensor_operator_registry.get('split')(axis, output_num)(self)
+
 
 class RowTensor(RowTensor_):
     """
diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py
index 36d17ca6c58..a285ecf0249 100644
--- a/mindspore/python/mindspore/ops/function/__init__.py
+++ b/mindspore/python/mindspore/ops/function/__init__.py
@@ -89,6 +89,7 @@ from .array_func import (
     broadcast_to,
     adaptive_max_pool2d,
     col2im,
+    split
 )
 from .parameter_func import (
     assign,
diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py
index ddfb0fc3fb3..8e9b8a5ce09 100644
--- a/mindspore/python/mindspore/ops/function/array_func.py
+++ b/mindspore/python/mindspore/ops/function/array_func.py
@@ -2991,6 +2991,58 @@ def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
     return c2i(input_x, output_size)
 
 
+def split(input_x, axis=0, output_num=1):
+    r"""
+    Splits the input tensor into `output_num` tensors along the given `axis`.
+
+    The `input_x` tensor will be split into equally sized sub-tensors.
+    This requires that `input_x.shape[axis]` is divisible by `output_num`.
+
+    Args:
+        input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+        axis (int): Index of the split position. Default: 0.
+        output_num (int): The number of output tensors. Must be a positive int. Default: 1.
+
+    Returns:
+        tuple[Tensor], the shape of each output tensor is the same, which is
+        :math:`(y_1, y_2, ..., y_S)`. The data type is the same as `input_x`.
+
+    Raises:
+        TypeError: If `axis` or `output_num` is not an int.
+        ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
+            or if the `output_num` is less than or equal to 0.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
+        >>> print(x)
+        [[1 1 1 1]
+         [2 2 2 2]]
+        >>> output = ops.split(x, 1, 2)
+        >>> print(output)
+        (Tensor(shape=[2, 2], dtype=Int32, value=
+        [[1, 1],
+         [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
+        [[1, 1],
+         [2, 2]]))
+        >>> output = ops.split(x, 1, 4)
+        >>> print(output)
+        (Tensor(shape=[2, 1], dtype=Int32, value=
+        [[1],
+         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
+        [[1],
+         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
+        [[1],
+         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
+        [[1],
+         [2]]))
+    """
+    split_ = P.Split(axis, output_num)
+    return split_(input_x)
+
+
 __all__ = [
     'unique',
     'unique_consecutive',
@@ -3054,6 +3106,7 @@ __all__ = [
     'adaptive_max_pool2d',
     'meshgrid',
     'broadcast_to',
-    'col2im'
+    'col2im',
+    'split'
 ]
 __all__.sort()
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index f19455458a1..b103c83a241 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -979,6 +979,7 @@ tensor_operator_registry.register('inplace_sub', P.InplaceSub)
 tensor_operator_registry.register('adaptive_avgpool2d', adaptive_avgpool2d)
 tensor_operator_registry.register('col2im', col2im)
 tensor_operator_registry.register('standard_laplace', P.StandardLaplace)
+tensor_operator_registry.register('split', P.Split)
 # ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index 590d961c189..207a840a33b 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -1181,24 +1181,7 @@ class Split(PrimitiveWithCheck):
     """
     Splits the input tensor into output_num of tensors along the given axis and output numbers.
 
-    The `input_x` tensor will be split into equally sized sub-tensors.
-    This requires that `input_x.shape(axis)` is divisible by `output_num`.
-
-    Args:
-        axis (int): Index of the split position. Default: 0.
-        output_num (int): The number of output tensors. Must be positive int. Default: 1.
-
-    Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-
-    Outputs:
-        tuple[Tensor], the shape of each output tensor is the same, which is
-        :math:`(y_1, y_2, ..., y_S)`. And the data type is the same with `input_x`.
-
-    Raises:
-        TypeError: If `axis` or `output_num` is not an int.
-        ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
-            or if the `output_num` is less than or equal to 0.
+    Refer to :func:`mindspore.ops.split` for more detail.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
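
Usage note: a minimal sketch of the two call paths this patch wires up, assuming a MindSpore
build with the patch applied; the input values follow the docstring examples above.

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)

    # Functional interface added in array_func.py; it wraps P.Split(axis, output_num).
    y0, y1 = ops.split(x, 1, 2)   # two tensors of shape (2, 2)

    # Tensor method added in tensor.py, dispatched through tensor_operator_registry.
    z0, z1 = x.split(1, 2)        # same result as the functional call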