!35964 [MS][LITE] UnsortedSegmentSum ops dynamic shape support
Merge pull request !35964 from luoyuan/UnsortedSegmentSum-dynamic-shape
This commit is contained in:
commit
260fa2d43e
|
@ -379,6 +379,7 @@ Array操作
|
|||
mindspore.ops.unique
|
||||
mindspore.ops.unique_consecutive
|
||||
mindspore.ops.unique_with_pad
|
||||
mindspore.ops.unsorted_segment_sum
|
||||
|
||||
.. list-table::
|
||||
:widths: 50 50
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
mindspore.ops.unsorted_segment_sum
|
||||
================================
|
||||
|
||||
.. py:function:: mindspore.ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
|
||||
|
||||
沿分段计算输入Tensor元素的和。
|
||||
|
||||
计算输出Tensor :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]` ,其中 :math:`j,...` 是代表元素索引的Tuple。 `segment_ids` 确定输入Tensor元素的分段。 `segment_ids` 不需要排序,也不需要覆盖 `num_segments` 范围内的所有值。
|
||||
|
||||
UnsortedSegmentSum的计算过程如下图所示:
|
||||
|
||||
.. image:: UnsortedSegmentSum.png
|
||||
|
||||
.. note::
|
||||
- 如果 `segment_ids` 中不存在segment_id `i` ,则对输出 `output[i]` 填充0。
|
||||
- 在Ascend平台上,如果segment_id的值小于0或大于输入Tensor的shape的长度,将触发执行错误。
|
||||
|
||||
如果给定的segment_ids :math: `i` 的和为空,则:math: `\text{output}[i] = 0` 。如果 `segment_ids` 元素为负数,将忽略该值。 `num_segments` 必须等于不同segment_id的数量。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **input_x** (Tensor)- shape: :math:`(x_1, x_2, ..., x_R)` 。
|
||||
- **segment_ids** (Tensor) - 将形状设置为 :math:`(x_1, x_2, ..., x_N)` ,其中0<N<=R。
|
||||
- **num_segments** (Union[int, Tensor], optional) - 分段数量 :math:`z` ,数据类型为int或0维的Tensor。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,shape: :math:`(z, x_{N+1}, ..., x_R)` 。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `num_segments` 不是int类型或Tensor的dtype不是int32或int64。
|
||||
- **ValueError** - `segment_ids` 的维度小于1。
|
||||
- **ValueError** - `segment_ids` 的维度大于`input_x`的维度。
|
||||
- **ValueError** - `num_segments` int类型的值小于0或Tensor类型的元素小于0。
|
||||
- **ValueError** - `num_segments` 是长度不等于0的Tensor。
|
||||
|
|
@ -378,6 +378,7 @@ Array Operation
|
|||
mindspore.ops.unique_consecutive
|
||||
mindspore.ops.unique_with_pad
|
||||
mindspore.ops.gumbel_softmax
|
||||
mindspore.ops.unsorted_segment_sum
|
||||
|
||||
.. list-table::
|
||||
:widths: 50 50
|
||||
|
|
|
@ -533,13 +533,15 @@ void TbeJsonCreator::GenInputConstValue(const AnfNodePtr &anf_node, size_t real_
|
|||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
ParseConstValue(value, input_desc);
|
||||
if (value) {
|
||||
ParseConstValue(value, input_desc);
|
||||
}
|
||||
} else if (input_node->isa<Parameter>()) {
|
||||
auto param = input_node->cast<ParameterPtr>();
|
||||
auto value = param->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
ParseConstValue(value, input_desc);
|
||||
if (value) {
|
||||
ParseConstValue(value, input_desc);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The operator " << anf_node->fullname_with_scope() << "'s input" << real_input_index
|
||||
<< "'s value depend is " << value_depend << ", but its input node is a " << input_node->type_name()
|
||||
|
|
|
@ -360,6 +360,7 @@ constexpr const char kNameStringLength[] = "StringLength";
|
|||
constexpr const char kNameGetShape[] = "GetShape";
|
||||
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
|
||||
constexpr const char kNameRandomStandardNormal[] = "RandomStandardNormal";
|
||||
constexpr const char kNameUnsortedSegmentSum[] = "UnsortedSegmentSum";
|
||||
|
||||
class OpAdapterDesc;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -166,11 +166,10 @@ OUTPUT_MAP(StridedSliceV2) = {{0, OUTPUT_DESC(y)}};
|
|||
REG_ADPT_DESC(StridedSliceV2, kNameStridedSliceV2, ADPT_DESC(StridedSliceV2))
|
||||
|
||||
// UnsortedSegmentSum
|
||||
INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
|
||||
INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
|
||||
ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(UnsortedSegmentSumD, prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD))
|
||||
INPUT_MAP(UnsortedSegmentSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
|
||||
ATTR_MAP(UnsortedSegmentSum) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentSum) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSum))
|
||||
|
||||
// UnsortedSegmentProdD
|
||||
INPUT_MAP(UnsortedSegmentProdD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -56,9 +56,8 @@ DECLARE_OP_USE_OUTPUT(StridedSlice)
|
|||
DECLARE_OP_ADAPTER(StridedSliceV2)
|
||||
DECLARE_OP_USE_OUTPUT(StridedSliceV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentSum)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSum)
|
||||
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentProdD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentProdD)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -26,7 +27,134 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string &op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
|
||||
auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
|
||||
CheckAndConvertUtils::CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
|
||||
auto x_shape_rank = SizeToLong(x_shape.size());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x size", x_shape_rank, kGreaterThan, 0, op_name);
|
||||
auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto segment_ids_shape_rank = SizeToLong(segment_ids_shape.size());
|
||||
(void)CheckAndConvertUtils::CheckInteger("segment_ids size", segment_ids_shape_rank, kGreaterThan, 0, op_name);
|
||||
ShapeVector output_shape;
|
||||
ShapeVector out_min_shape;
|
||||
ShapeVector out_max_shape;
|
||||
constexpr int dynamic_rank_len = 1;
|
||||
constexpr int dynamic_rank_value = -2;
|
||||
if ((x_shape_rank == dynamic_rank_len && x_shape[0] == dynamic_rank_value) ||
|
||||
(segment_ids_shape_rank == dynamic_rank_len && segment_ids_shape[0] == dynamic_rank_value)) {
|
||||
output_shape = {dynamic_rank_value}; // unknown dimension
|
||||
out_min_shape = {0};
|
||||
out_max_shape = {abstract::Shape::SHP_ANY};
|
||||
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kGreaterEqual, "segment_ids_shape rank",
|
||||
segment_ids_shape.size(), op_name);
|
||||
for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
|
||||
if (segment_ids_shape[i] == abstract::Shape::SHP_ANY || x_shape[i] == abstract::Shape::SHP_ANY) continue;
|
||||
if (segment_ids_shape[i] != x_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name
|
||||
<< "', the whose shape of 'segment_ids' must be a prefix of the shape of 'input_x', "
|
||||
"but got 'segment_ids_shape["
|
||||
<< i << "]': " << segment_ids_shape[i] << " and 'input_x_shape[" << i
|
||||
<< "]': " << x_shape[i];
|
||||
}
|
||||
}
|
||||
abstract::CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
|
||||
abstract::CheckShapeAnyAndPositive(op_name + " segment_ids_shape", segment_ids_shape);
|
||||
|
||||
ShapeVector num_vec;
|
||||
ShapeVector num_min_vec;
|
||||
ShapeVector num_max_vec;
|
||||
int64_t num_segments_v = 0;
|
||||
// num_segments is a Tensor when UnsortedSegmentSum is a dynamic shape operator
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
auto n_value_ptr = input_args[kInputIndex2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(n_value_ptr);
|
||||
if (n_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto n_tensor_ptr = n_value_ptr->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_tensor_ptr);
|
||||
num_segments_v = *static_cast<int64_t *>(n_tensor_ptr->data_c());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec.push_back(num_segments_v);
|
||||
num_min_vec.push_back(num_segments_v);
|
||||
num_max_vec.push_back(num_segments_v);
|
||||
} else {
|
||||
auto n_abstract_tensor = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(n_abstract_tensor);
|
||||
num_vec.push_back(-1);
|
||||
auto num_min_value = n_abstract_tensor->get_min_value();
|
||||
auto num_max_value = n_abstract_tensor->get_max_value();
|
||||
if (num_min_value != nullptr && num_max_value != nullptr) {
|
||||
num_min_vec = GetValue<ShapeVector>(num_min_value);
|
||||
num_max_vec = GetValue<ShapeVector>(num_max_value);
|
||||
} else {
|
||||
num_min_vec.push_back(-1);
|
||||
num_max_vec.push_back(-1);
|
||||
}
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
num_segments_v = GetValue<int64_t>(input_args[kInputIndex2]->BuildValue());
|
||||
(void)CheckAndConvertUtils::CheckInteger("num_segments's value", num_segments_v, kGreaterThan, 0, op_name);
|
||||
num_vec.push_back(num_segments_v);
|
||||
num_min_vec.push_back(num_segments_v);
|
||||
num_max_vec.push_back(num_segments_v);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name
|
||||
<< "', the third input type should be tensor or scalar, but got invalid abstract type:"
|
||||
<< input_args[kInputIndex2]->type_name() << ".";
|
||||
}
|
||||
auto calc_shape = [segment_ids_shape_rank](const ShapeVector &num_vec, const ShapeVector &x_shape) -> ShapeVector {
|
||||
ShapeVector out_vec;
|
||||
(void)copy(num_vec.begin(), num_vec.end(), std::back_inserter(out_vec));
|
||||
(void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
|
||||
return out_vec;
|
||||
};
|
||||
output_shape = calc_shape(num_vec, x_shape);
|
||||
out_min_shape = calc_shape(num_min_vec, x_min_shape);
|
||||
out_max_shape = calc_shape(num_max_vec, x_max_shape);
|
||||
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
|
||||
}
|
||||
|
||||
TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
/* check segment_ids */
|
||||
auto ids_ptr = input_args[kInputIndex1]->BuildType();
|
||||
std::set<TypePtr> ids_type_set = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids type", ids_ptr, ids_type_set, prim_name);
|
||||
/* check num_segments */
|
||||
auto num_ptr = input_args[kInputIndex2]->BuildType();
|
||||
std::map<std::string, TypePtr> args_num_segments;
|
||||
(void)args_num_segments.insert({"num_segments", num_ptr});
|
||||
const std::set<TypePtr> num_type_set = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_num_segments, num_type_set, prim_name);
|
||||
/* check input_x */
|
||||
auto x_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||
std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kInt32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type_ptr, x_type_set, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
MIND_API_OPERATOR_IMPL(UnsortedSegmentSum, BaseOperator);
|
||||
REGISTER_PRIMITIVE_C(kNameUnsortedSegmentSum, UnsortedSegmentSum);
|
||||
|
||||
AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kMinInputNum = 2;
|
||||
const int64_t kMaxInputNum = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kMinInputNum, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kLessEqual, kMaxInputNum, primitive->name());
|
||||
auto infer_type = UnsortedSegmentSumInferType(primitive, input_args);
|
||||
auto infer_shape = UnsortedSegmentSumInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
REGISTER_HOST_DEPENDS(kNameUnsortedSegmentSum, {2});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(UnsortedSegmentSum, prim::kPrimUnsortedSegmentSum, UnsortedSegmentSumInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -23,17 +23,19 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("unsorted_segment_sum") \
|
||||
.partial_flag(True) \
|
||||
.need_check_supported(True) \
|
||||
.dynamic_compile_static(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "segment_ids", False, "required", "all") \
|
||||
.input(2, "num_segments", False, "required", "all") \
|
||||
.input(1, "segment_ids", False, "required", "all", "optional") \
|
||||
.input(2, "num_segments", False, "required", "all", "optional") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(unsorted_segment_sum_ds_op_info)
|
||||
def _unsorted_segment_sum_ds_tbe():
|
||||
"""UnsortedSegmentSumUnknown TBE register"""
|
||||
"""UnsortedSegmentSumD TBE register"""
|
||||
return
|
||||
|
|
|
@ -97,6 +97,7 @@ from .array_func import (
|
|||
affine_grid,
|
||||
fills,
|
||||
broadcast_to,
|
||||
unsorted_segment_sum,
|
||||
adaptive_max_pool2d,
|
||||
col2im,
|
||||
split,
|
||||
|
|
|
@ -80,6 +80,7 @@ zeros_like_ = P.ZerosLike()
|
|||
cast_ = P.Cast()
|
||||
tensor_select_ = P.Select()
|
||||
index_fill_ = IndexFill()
|
||||
unsorted_segment_sum_ = P.UnsortedSegmentSum()
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -3738,6 +3739,62 @@ def min(x, axis=0, keep_dims=False):
|
|||
argmin_with_value_op = P.ArgMinWithValue(axis, keep_dims)
|
||||
return argmin_with_value_op(x)
|
||||
|
||||
|
||||
def unsorted_segment_sum(input_x, segment_ids, num_segments):
|
||||
r"""
|
||||
Computes the sum of a tensor along segments.
|
||||
|
||||
Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
|
||||
:math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
|
||||
up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
|
||||
range.
|
||||
|
||||
The following figure shows the calculation process of UnsortedSegmentSum:
|
||||
|
||||
.. image:: UnsortedSegmentSum.png
|
||||
|
||||
Note:
|
||||
- If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.
|
||||
- On Ascend, if the value of segment_id is less than 0 or greater than the length of the input data shape, an
|
||||
execution error will occur.
|
||||
|
||||
If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
|
||||
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
|
||||
|
||||
Args:
|
||||
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
|
||||
- **num_segments** (int) - Set :math:`z` as num_segments.
|
||||
|
||||
Returns:
|
||||
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_segments` is not an int.
|
||||
ValueError: If length of shape of `segment_ids` is less than 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.ops as ops
|
||||
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
|
||||
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
|
||||
>>> num_segments = 4
|
||||
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
|
||||
>>> print(output)
|
||||
[3. 3. 4. 0.]
|
||||
>>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
|
||||
>>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
|
||||
>>> num_segments = 6
|
||||
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
|
||||
>>> print(output)
|
||||
[3. 3. 4. 2. 5. 0.]
|
||||
"""
|
||||
return unsorted_segment_sum_(input_x, segment_ids, num_segments)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'unique_with_pad',
|
||||
|
@ -3816,5 +3873,6 @@ __all__ = [
|
|||
"index_fill",
|
||||
'max',
|
||||
'min',
|
||||
'unsorted_segment_sum',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -916,5 +916,6 @@ tensor_operator_registry.register('renorm', renorm)
|
|||
tensor_operator_registry.register('adaptive_max_pool2d', AdaptiveMaxPool2D)
|
||||
tensor_operator_registry.register('coalesce', coalesce)
|
||||
tensor_operator_registry.register('arg_min_with_value', min)
|
||||
tensor_operator_registry.register('unsorted_segment_sum', P.UnsortedSegmentSum)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -2406,34 +2406,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
r"""
|
||||
Computes the sum of a tensor along segments.
|
||||
|
||||
Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
|
||||
:math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
|
||||
up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
|
||||
range.
|
||||
|
||||
The following figure shows the calculation process of UnsortedSegmentSum:
|
||||
|
||||
.. image:: UnsortedSegmentSum.png
|
||||
|
||||
Note:
|
||||
- If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.
|
||||
- On Ascend, if the value of segment_id is less than 0 or greater than the length of the input data shape, an
|
||||
execution error will occur.
|
||||
|
||||
If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
|
||||
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
|
||||
- **num_segments** (int) - Set :math:`z` as num_segments.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_segments` is not an int.
|
||||
ValueError: If length of shape of `segment_ids` is less than 1.
|
||||
Refer to :func:`mindspore.ops.unsorted_segment_sum` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue