!38293 [MS][LITE] UnsortedSegmentSum dynamic shape fix and code docs fix

Merge pull request !38293 from luoyuan/fix-unsortedsegmentsum-ops-infershape-and-docs
i-robot 2022-07-19 09:28:22 +00:00 committed by Gitee
commit 5256155f29
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 36 additions and 55 deletions

View File

@ -5,29 +5,4 @@
Computes the sum of the elements of the input Tensor along segments.

Computes the output Tensor :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where :math:`j,...` is a tuple describing the index of an element. `segment_ids` determines the segmentation of the elements of the input Tensor. `segment_ids` does not need to be sorted, and it does not need to cover all values in the range of `num_segments`.

The following figure shows the calculation process of UnsortedSegmentSum:

.. image:: UnsortedSegmentSum.png

.. note::
    - If no segment_id `i` exists in `segment_ids`, the output `output[i]` is filled with 0.
    - On the Ascend platform, an execution error is triggered if the value of a segment_id is less than 0 or greater than the length of the shape of the input Tensor.

If an element of `segment_ids` is negative, the value is ignored. `num_segments` must be equal to the number of distinct segment_ids.

**Inputs:**

- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A 1-D tensor of shape :math:`(x_1)` whose values must be non-negative. The supported data type is int32.
- **num_segments** (int) - The number of segments :math:`z`.

**Outputs:**

Tensor, with shape :math:`(z, x_{N+1}, ..., x_R)`.

**Raises:**

- **TypeError** - `num_segments` is not of type int.
- **ValueError** - The rank of `segment_ids` is less than 1.

For more details, see :func:`mindspore.ops.unsorted_segment_sum`.
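For reference, a minimal NumPy sketch of the formula above for the 1-D `segment_ids` case (illustrative only, not the MindSpore implementation; the name `unsorted_segment_sum_ref` is invented for this sketch):

import numpy as np

def unsorted_segment_sum_ref(data, segment_ids, num_segments):
    """output[i] = sum of data[j, ...] over all j with segment_ids[j] == i."""
    data = np.asarray(data, dtype=np.float32)
    segment_ids = np.asarray(segment_ids)
    # The output keeps the trailing dimensions of `data`; segments that never
    # appear in `segment_ids` stay filled with 0.
    output = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for j, seg in enumerate(segment_ids):
        # Negative ids are ignored, matching the note above; out-of-range ids are
        # also skipped here for simplicity (Ascend raises an error instead).
        if 0 <= seg < num_segments:
            output[seg] += data[j]
    return output

print(unsorted_segment_sum_ref([1, 2, 3, 4], [0, 0, 1, 2], 4))  # -> [3. 3. 4. 0.]

This matches the first example in the Python docstring further down in this diff.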

View File

@ -42,27 +42,27 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
auto segment_ids_shape_rank = SizeToLong(segment_ids_shape.size());
(void)CheckAndConvertUtils::CheckInteger("segment_ids size", segment_ids_shape_rank, kGreaterThan, 0, op_name);
ShapeVector output_shape;
ShapeVector out_min_shape;
ShapeVector out_max_shape;
constexpr int dynamic_rank_len = 1;
constexpr int dynamic_rank_value = -2;
if ((x_shape_rank == dynamic_rank_len && x_shape[0] == dynamic_rank_value) ||
(segment_ids_shape_rank == dynamic_rank_len && segment_ids_shape[0] == dynamic_rank_value)) {
output_shape = {dynamic_rank_value}; // unknown dimension
out_min_shape = {0};
out_max_shape = {abstract::Shape::SHP_ANY};
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
return std::make_shared<abstract::Shape>(output_shape);
}
(void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kGreaterEqual, "segment_ids_shape rank",
segment_ids_shape.size(), op_name);
for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
if (segment_ids_shape[i] == abstract::Shape::SHP_ANY || x_shape[i] == abstract::Shape::SHP_ANY) continue;
if (segment_ids_shape[i] != x_shape[i]) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the whose shape of 'segment_ids' must be a prefix of the shape of 'input_x', "
"but got 'segment_ids_shape["
<< i << "]': " << segment_ids_shape[i] << " and 'input_x_shape[" << i
<< "]': " << x_shape[i];
bool x_any_shape =
std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
bool ids_any_shape = std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(),
[](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
if (!x_any_shape && !ids_any_shape) {
for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
if (segment_ids_shape[i] != x_shape[i]) {
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the whose shape of 'segment_ids' must equal to the shape of 'input_x', ids["
<< i << "] should be = input[" << i << "]: " << x_shape[i] << ", but got "
<< segment_ids_shape[i];
}
}
}
abstract::CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
@ -93,9 +93,6 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
if (num_min_value != nullptr && num_max_value != nullptr) {
num_min_vec = GetValue<ShapeVector>(num_min_value);
num_max_vec = GetValue<ShapeVector>(num_max_value);
} else {
num_min_vec.push_back(-1);
num_max_vec.push_back(-1);
}
}
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
@ -115,10 +112,21 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
(void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
return out_vec;
};
ShapeVector out_min_shape;
ShapeVector out_max_shape;
output_shape = calc_shape(num_vec, x_shape);
out_min_shape = calc_shape(num_min_vec, x_min_shape);
out_max_shape = calc_shape(num_max_vec, x_max_shape);
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
auto output_shape_rank = SizeToLong(output_shape.size());
auto out_max_shape_rank = SizeToLong(calc_shape(num_max_vec, x_max_shape).size());
bool x_min_any_shape =
std::any_of(x_min_shape.begin(), x_min_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
bool x_max_any_shape =
std::any_of(x_max_shape.begin(), x_max_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
if ((out_max_shape_rank == output_shape_rank) && (!x_min_any_shape && !x_max_any_shape)) {
out_min_shape = calc_shape(num_min_vec, x_min_shape);
out_max_shape = calc_shape(num_max_vec, x_max_shape);
return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
@ -126,17 +134,17 @@ TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::ve
auto prim_name = primitive->name();
/* check segment_ids */
auto ids_ptr = input_args[kInputIndex1]->BuildType();
std::set<TypePtr> ids_type_set = {kInt32, kInt64};
std::set<TypePtr> ids_type_set = {kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids type", ids_ptr, ids_type_set, prim_name);
/* check num_segments */
auto num_ptr = input_args[kInputIndex2]->BuildType();
std::map<std::string, TypePtr> args_num_segments;
(void)args_num_segments.insert({"num_segments", num_ptr});
const std::set<TypePtr> num_type_set = {kInt32, kInt64};
const std::set<TypePtr> num_type_set = {kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_num_segments, num_type_set, prim_name);
/* check input_x */
auto x_type_ptr = input_args[kInputIndex0]->BuildType();
std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kInt32};
std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kFloat64, kInt32, kInt64};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type_ptr, x_type_set, prim_name);
}
} // namespace
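For orientation, a hedged Python sketch of the static-shape rule that UnsortedSegmentSumInferShape implements (the dynamic-rank sentinel, SHP_ANY dimensions, and min/max shape branches handled in the C++ above are deliberately omitted; the function name is invented for this sketch):

def infer_unsorted_segment_sum_shape(x_shape, segment_ids_shape, num_segments):
    """The shape of segment_ids must match the leading dimensions of input_x;
    the output shape is [num_segments] followed by the remaining x dimensions."""
    ids_rank = len(segment_ids_shape)
    if len(x_shape) < ids_rank:
        raise ValueError("rank of 'input_x' must be >= rank of 'segment_ids'")
    for i in range(ids_rank):
        if segment_ids_shape[i] != x_shape[i]:
            raise ValueError(f"segment_ids_shape[{i}] = {segment_ids_shape[i]} "
                             f"must equal input_x_shape[{i}] = {x_shape[i]}")
    # Mirrors the calc_shape lambda: num_segments first, then x_shape[ids_rank:].
    return [num_segments] + list(x_shape[ids_rank:])

print(infer_unsorted_segment_sum_shape([4, 5, 6], [4, 5], 3))  # -> [3, 6]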

View File

@ -23,8 +23,6 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
.compute_cost(10) \
.kernel_name("unsorted_segment_sum") \
.partial_flag(True) \
.need_check_supported(True) \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all", "optional") \

View File

@ -3967,9 +3967,9 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
Args:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
- **num_segments** (int) - Set :math:`z` as num_segments.
input_x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`.
segment_ids (Tensor): Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
num_segments (int): Set :math:`z` as num_segments.
Returns:
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
@ -3987,13 +3987,13 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
>>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
>>> print(output)
[3. 3. 4. 0.]
>>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
>>> num_segments = 6
>>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
>>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
>>> print(output)
[3. 3. 4. 2. 5. 0.]
"""