!35964 [MS][LITE] UnsortedSegmentSum ops dynamic shape support

Merge pull request !35964 from luoyuan/UnsortedSegmentSum-dynamic-shape
This commit is contained in:
i-robot 2022-07-12 10:34:57 +00:00 committed by Gitee
commit 260fa2d43e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 253 additions and 50 deletions

View File

@ -379,6 +379,7 @@ Array操作
mindspore.ops.unique
mindspore.ops.unique_consecutive
mindspore.ops.unique_with_pad
mindspore.ops.unsorted_segment_sum
.. list-table::
:widths: 50 50

View File

@ -0,0 +1,37 @@
mindspore.ops.unsorted_segment_sum
================================
.. py:function:: mindspore.ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
沿分段计算输入Tensor元素的和。
计算输出Tensor :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]` ,其中 :math:`j,...` 是代表元素索引的Tuple。 `segment_ids` 确定输入Tensor元素的分段。 `segment_ids` 不需要排序,也不需要覆盖 `num_segments` 范围内的所有值。
UnsortedSegmentSum的计算过程如下图所示
.. image:: UnsortedSegmentSum.png
.. note::
- 如果 `segment_ids` 中不存在segment_id `i` ,则对输出 `output[i]` 填充0。
- 在Ascend平台上如果segment_id的值小于0或大于输入Tensor的shape的长度将触发执行错误。
如果给定的segment_ids :math: `i` 的和为空math: `\text{output}[i] = 0` 。如果 `segment_ids` 元素为负数,将忽略该值。 `num_segments` 必须等于不同segment_id的数量。
**参数:**
- **input_x** (Tensor)- shape :math:`(x_1, x_2, ..., x_R)`
- **segment_ids** (Tensor) - 将形状设置为 :math:`(x_1, x_2, ..., x_N)` 其中0<N<=R。
- **num_segments** (Union[int, Tensor], optional) - 分段数量 :math:`z` 数据类型为int或0维的Tensor。
**返回:**
Tensorshape :math:`(z, x_{N+1}, ..., x_R)`
**异常:**
- **TypeError** - `num_segments` 不是int类型或Tensor的dtype不是int32或int64。
- **ValueError** - `segment_ids` 的维度小于1。
- **ValueError** - `segment_ids` 的维度大于`input_x`的维度。
- **ValueError** - `num_segments` int类型的值小于0或Tensor类型的元素小于0。
- **ValueError** - `num_segments` 是长度不等于0的Tensor。

View File

@ -378,6 +378,7 @@ Array Operation
mindspore.ops.unique_consecutive
mindspore.ops.unique_with_pad
mindspore.ops.gumbel_softmax
mindspore.ops.unsorted_segment_sum
.. list-table::
:widths: 50 50

View File

@ -533,13 +533,15 @@ void TbeJsonCreator::GenInputConstValue(const AnfNodePtr &anf_node, size_t real_
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value) {
ParseConstValue(value, input_desc);
}
} else if (input_node->isa<Parameter>()) {
auto param = input_node->cast<ParameterPtr>();
auto value = param->default_param();
MS_EXCEPTION_IF_NULL(value);
if (value) {
ParseConstValue(value, input_desc);
}
} else {
MS_LOG(ERROR) << "The operator " << anf_node->fullname_with_scope() << "'s input" << real_input_index
<< "'s value depend is " << value_depend << ", but its input node is a " << input_node->type_name()

View File

@ -360,6 +360,7 @@ constexpr const char kNameStringLength[] = "StringLength";
constexpr const char kNameGetShape[] = "GetShape";
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
constexpr const char kNameRandomStandardNormal[] = "RandomStandardNormal";
constexpr const char kNameUnsortedSegmentSum[] = "UnsortedSegmentSum";
class OpAdapterDesc;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -166,11 +166,10 @@ OUTPUT_MAP(StridedSliceV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(StridedSliceV2, kNameStridedSliceV2, ADPT_DESC(StridedSliceV2))
// UnsortedSegmentSum
INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UnsortedSegmentSumD, prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD))
INPUT_MAP(UnsortedSegmentSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
ATTR_MAP(UnsortedSegmentSum) = EMPTY_ATTR_MAP;
OUTPUT_MAP(UnsortedSegmentSum) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSum))
// UnsortedSegmentProdD
INPUT_MAP(UnsortedSegmentProdD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -56,9 +56,8 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
DECLARE_OP_ADAPTER(StridedSliceV2)
DECLARE_OP_USE_OUTPUT(StridedSliceV2)
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
DECLARE_OP_ADAPTER(UnsortedSegmentSum)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSum)
DECLARE_OP_ADAPTER(UnsortedSegmentProdD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentProdD)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@
#include <algorithm>
#include <memory>
#include <set>
#include <map>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
@ -26,7 +27,134 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
CheckAndConvertUtils::CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
auto x_shape_rank = SizeToLong(x_shape.size());
(void)CheckAndConvertUtils::CheckInteger("input_x size", x_shape_rank, kGreaterThan, 0, op_name);
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto segment_ids_shape_rank = SizeToLong(segment_ids_shape.size());
(void)CheckAndConvertUtils::CheckInteger("segment_ids size", segment_ids_shape_rank, kGreaterThan, 0, op_name);
ShapeVector output_shape;
ShapeVector out_min_shape;
ShapeVector out_max_shape;
constexpr int dynamic_rank_len = 1;
constexpr int dynamic_rank_value = -2;
if ((x_shape_rank == dynamic_rank_len && x_shape[0] == dynamic_rank_value) ||
(segment_ids_shape_rank == dynamic_rank_len && segment_ids_shape[0] == dynamic_rank_value)) {
output_shape = {dynamic_rank_value}; // unknown dimension
out_min_shape = {0};
out_max_shape = {abstract::Shape::SHP_ANY};
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
}
(void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kGreaterEqual, "segment_ids_shape rank",
segment_ids_shape.size(), op_name);
for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
if (segment_ids_shape[i] == abstract::Shape::SHP_ANY || x_shape[i] == abstract::Shape::SHP_ANY) continue;
if (segment_ids_shape[i] != x_shape[i]) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the whose shape of 'segment_ids' must be a prefix of the shape of 'input_x', "
"but got 'segment_ids_shape["
<< i << "]': " << segment_ids_shape[i] << " and 'input_x_shape[" << i
<< "]': " << x_shape[i];
}
}
abstract::CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
ShapeVector num_vec;
ShapeVector num_min_vec;
ShapeVector num_max_vec;
int64_t num_segments_v = 0;
// num_segments is a Tensor when UnsortedSegmentSum is a dynamic shape operator
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(n_value_ptr);
if (n_value_ptr->isa<tensor::Tensor>()) {
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
num_segments_v = *static_cast<int64_t *>(n_tensor_ptr->data_c());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec.push_back(num_segments_v);
num_min_vec.push_back(num_segments_v);
num_max_vec.push_back(num_segments_v);
} else {
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
num_vec.push_back(-1);
auto num_min_value = n_abstract_tensor->get_min_value();
auto num_max_value = n_abstract_tensor->get_max_value();
if (num_min_value != nullptr && num_max_value != nullptr) {
num_min_vec = GetValue<ShapeVector>(num_min_value);
num_max_vec = GetValue<ShapeVector>(num_max_value);
} else {
num_min_vec.push_back(-1);
num_max_vec.push_back(-1);
}
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
num_vec.push_back(num_segments_v);
num_min_vec.push_back(num_segments_v);
num_max_vec.push_back(num_segments_v);
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex2]->type_name() << ".";
}
auto calc_shape = [segment_ids_shape_rank](const ShapeVector &num_vec, const ShapeVector &x_shape) -> ShapeVector {
ShapeVector out_vec;
(void)copy(num_vec.begin(), num_vec.end(), std::back_inserter(out_vec));
(void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
return out_vec;
};
output_shape = calc_shape(num_vec, x_shape);
out_min_shape = calc_shape(num_min_vec, x_min_shape);
out_max_shape = calc_shape(num_max_vec, x_max_shape);
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
}
TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
/* check segment_ids */
auto ids_ptr = input_args[kInputIndex1]->BuildType();
std::set<TypePtr> ids_type_set = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids type", ids_ptr, ids_type_set, prim_name);
/* check num_segments */
auto num_ptr = input_args[kInputIndex2]->BuildType();
std::map<std::string, TypePtr> args_num_segments;
(void)args_num_segments.insert({"num_segments", num_ptr});
const std::set<TypePtr> num_type_set = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_num_segments, num_type_set, prim_name);
/* check input_x */
auto x_type_ptr = input_args[kInputIndex0]->BuildType();
std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kInt32};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type_ptr, x_type_set, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(UnsortedSegmentSum, BaseOperator);
REGISTER_PRIMITIVE_C(kNameUnsortedSegmentSum, UnsortedSegmentSum);
AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kMinInputNum = 2;
const int64_t kMaxInputNum = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kMinInputNum, primitive->name());
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kLessEqual, kMaxInputNum, primitive->name());
auto infer_type = UnsortedSegmentSumInferType(primitive, input_args);
auto infer_shape = UnsortedSegmentSumInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_HOST_DEPENDS(kNameUnsortedSegmentSum, {2});
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum, UnsortedSegmentSumInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,17 +23,19 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
.compute_cost(10) \
.kernel_name("unsorted_segment_sum") \
.partial_flag(True) \
.need_check_supported(True) \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.input(2, "num_segments", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all", "optional") \
.input(2, "num_segments", False, "required", "all", "optional") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(unsorted_segment_sum_ds_op_info)
def _unsorted_segment_sum_ds_tbe():
"""UnsortedSegmentSumUnknown TBE register"""
"""UnsortedSegmentSumD TBE register"""
return

View File

@ -97,6 +97,7 @@ from .array_func import (
affine_grid,
fills,
broadcast_to,
unsorted_segment_sum,
adaptive_max_pool2d,
col2im,
split,

View File

@ -80,6 +80,7 @@ zeros_like_ = P.ZerosLike()
cast_ = P.Cast()
tensor_select_ = P.Select()
index_fill_ = IndexFill()
unsorted_segment_sum_ = P.UnsortedSegmentSum()
@constexpr
@ -3738,6 +3739,62 @@ def min(x, axis=0, keep_dims=False):
argmin_with_value_op = P.ArgMinWithValue(axis, keep_dims)
return argmin_with_value_op(x)
def unsorted_segment_sum(input_x, segment_ids, num_segments):
r"""
Computes the sum of a tensor along segments.
Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
:math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
range.
The following figure shows the calculation process of UnsortedSegmentSum:
.. image:: UnsortedSegmentSum.png
Note:
- If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.
- On Ascend, if the value of segment_id is less than 0 or greater than the length of the input data shape, an
execution error will occur.
If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
Args:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
- **num_segments** (int) - Set :math:`z` as num_segments.
Returns:
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Raises:
TypeError: If `num_segments` is not an int.
ValueError: If length of shape of `segment_ids` is less than 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
>>> import mindspore.ops as ops
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
>>> print(output)
[3. 3. 4. 0.]
>>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
>>> num_segments = 6
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
>>> print(output)
[3. 3. 4. 2. 5. 0.]
"""
return unsorted_segment_sum_(input_x, segment_ids, num_segments)
__all__ = [
'unique',
'unique_with_pad',
@ -3816,5 +3873,6 @@ __all__ = [
"index_fill",
'max',
'min',
'unsorted_segment_sum',
]
__all__.sort()

View File

@ -916,5 +916,6 @@ tensor_operator_registry.register('renorm', renorm)
tensor_operator_registry.register('adaptive_max_pool2d', AdaptiveMaxPool2D)
tensor_operator_registry.register('coalesce', coalesce)
tensor_operator_registry.register('arg_min_with_value', min)
tensor_operator_registry.register('unsorted_segment_sum', P.UnsortedSegmentSum)
__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')

View File

@ -2406,34 +2406,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
r"""
Computes the sum of a tensor along segments.
Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
:math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
range.
The following figure shows the calculation process of UnsortedSegmentSum:
.. image:: UnsortedSegmentSum.png
Note:
- If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.
- On Ascend, if the value of segment_id is less than 0 or greater than the length of the input data shape, an
execution error will occur.
If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
- **num_segments** (int) - Set :math:`z` as num_segments.
Outputs:
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Raises:
TypeError: If `num_segments` is not an int.
ValueError: If length of shape of `segment_ids` is less than 1.
Refer to :func:`mindspore.ops.unsorted_segment_sum` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``