!38293 [MS][LITE] UnsortedSegmentSum dynamic shape fix and code docs fix
Merge pull request !38293 from luoyuan/fix-unsortedsegmentsum-ops-infershape-and-docs
commit 5256155f29
@@ -5,29 +5,4 @@
Computes the sum of the elements of the input Tensor along segments.

Computes the output Tensor :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where :math:`j,...` is a tuple representing the element index. `segment_ids` determines the segment of each element of the input Tensor. `segment_ids` does not need to be sorted, and it does not need to cover all values in the range of `num_segments`.

The computation process of UnsortedSegmentSum is shown in the figure below:

.. image:: UnsortedSegmentSum.png

.. note::
    - If segment_id `i` is absent from `segment_ids`, the output `output[i]` is filled with 0.
    - On the Ascend platform, an execution error is triggered if the value of a segment_id is less than 0 or greater than the length of the shape of the input Tensor.

If an element of `segment_ids` is negative, the value is ignored. `num_segments` must be equal to the number of distinct segment_ids.

**Inputs:**

- **input_x** (Tensor) - Shape: :math:`(x_1, x_2, ..., x_R)`.
- **segment_ids** (Tensor) - A 1-D tensor of shape :math:`(x_1)` whose values must be non-negative. The supported data type is int32.
- **num_segments** (int) - The number of segments :math:`z`.

**Outputs:**

Tensor, shape: :math:`(z, x_{N+1}, ..., x_R)`.

**Raises:**

- **TypeError** - `num_segments` is not of type int.
- **ValueError** - The rank of `segment_ids` is less than 1.

For more details, see :func:`mindspore.ops.unsorted_segment_sum`.
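As a reading aid, the summation rule described above can be expressed as a short, self-contained Python sketch; the helper name unsorted_segment_sum_ref and the plain-list representation are illustrative only, not MindSpore code:

# Minimal reference sketch of the UnsortedSegmentSum semantics for 1-D data.
# Negative segment ids are ignored; segments that never occur stay 0.
def unsorted_segment_sum_ref(data, segment_ids, num_segments):
    output = [0.0] * num_segments
    for value, seg in zip(data, segment_ids):
        if seg < 0:  # negative ids are ignored
            continue
        output[seg] += value  # output[i] = sum of data[j] where segment_ids[j] == i
    return output

print(unsorted_segment_sum_ref([1, 2, 3, 4], [0, 0, 1, 2], 4))  # [3.0, 3.0, 4.0, 0.0]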
@@ -42,27 +42,27 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
  auto segment_ids_shape_rank = SizeToLong(segment_ids_shape.size());
  (void)CheckAndConvertUtils::CheckInteger("segment_ids size", segment_ids_shape_rank, kGreaterThan, 0, op_name);
  ShapeVector output_shape;
  ShapeVector out_min_shape;
  ShapeVector out_max_shape;
  constexpr int dynamic_rank_len = 1;
  constexpr int dynamic_rank_value = -2;
  if ((x_shape_rank == dynamic_rank_len && x_shape[0] == dynamic_rank_value) ||
      (segment_ids_shape_rank == dynamic_rank_len && segment_ids_shape[0] == dynamic_rank_value)) {
    output_shape = {dynamic_rank_value};  // unknown dimension
    out_min_shape = {0};
    out_max_shape = {abstract::Shape::SHP_ANY};
    return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
    return std::make_shared<abstract::Shape>(output_shape);
  }
  (void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kGreaterEqual, "segment_ids_shape rank",
                                                 segment_ids_shape.size(), op_name);
  for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
    if (segment_ids_shape[i] == abstract::Shape::SHP_ANY || x_shape[i] == abstract::Shape::SHP_ANY) continue;
    if (segment_ids_shape[i] != x_shape[i]) {
      MS_EXCEPTION(ValueError) << "For '" << op_name
                               << "', the shape of 'segment_ids' must be a prefix of the shape of 'input_x', "
                                  "but got 'segment_ids_shape["
                               << i << "]': " << segment_ids_shape[i] << " and 'input_x_shape[" << i
                               << "]': " << x_shape[i];
  bool x_any_shape =
    std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
  bool ids_any_shape = std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(),
                                   [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
  if (!x_any_shape && !ids_any_shape) {
    for (uint64_t i = 0; i < segment_ids_shape.size(); i++) {
      if (segment_ids_shape[i] != x_shape[i]) {
        MS_EXCEPTION(ValueError) << "For '" << op_name
                                 << "', the shape of 'segment_ids' must be equal to the shape of 'input_x', ids["
                                 << i << "] should be = input[" << i << "]: " << x_shape[i] << ", but got "
                                 << segment_ids_shape[i];
      }
    }
  }
  abstract::CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
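A minimal Python sketch of the same shape-validation idea as the hunk above, assuming the sentinel conventions visible in the diff (-2 for a fully dynamic rank, and SHP_ANY, assumed here to be -1, for a single unknown dimension); the helper check_segment_ids_shape is illustrative, not a MindSpore API:

# Sketch of the dynamic-rank bail-out and the leading-dimension check.
DYN_RANK = -2  # whole rank unknown
DYN_DIM = -1   # one dimension unknown (SHP_ANY)

def check_segment_ids_shape(x_shape, ids_shape):
    # With a fully dynamic rank on either input, the output rank is also unknown.
    if list(x_shape) == [DYN_RANK] or list(ids_shape) == [DYN_RANK]:
        return [DYN_RANK]
    if len(x_shape) < len(ids_shape):
        raise ValueError("rank of 'input_x' must be >= rank of 'segment_ids'")
    # Known leading dimensions of input_x must match segment_ids; unknown dims are skipped.
    for i, (ids_dim, x_dim) in enumerate(zip(ids_shape, x_shape)):
        if DYN_DIM in (ids_dim, x_dim):
            continue
        if ids_dim != x_dim:
            raise ValueError(f"segment_ids_shape[{i}]={ids_dim} != input_x_shape[{i}]={x_dim}")
    return None  # compatible; the output shape is computed in the later hunks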
@@ -93,9 +93,6 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
      if (num_min_value != nullptr && num_max_value != nullptr) {
        num_min_vec = GetValue<ShapeVector>(num_min_value);
        num_max_vec = GetValue<ShapeVector>(num_max_value);
      } else {
        num_min_vec.push_back(-1);
        num_max_vec.push_back(-1);
      }
    }
  } else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
@@ -115,10 +112,21 @@ abstract::ShapePtr UnsortedSegmentSumInferShape(const PrimitivePtr &primitive,
    (void)copy(x_shape.begin() + segment_ids_shape_rank, x_shape.end(), std::back_inserter(out_vec));
    return out_vec;
  };
  ShapeVector out_min_shape;
  ShapeVector out_max_shape;
  output_shape = calc_shape(num_vec, x_shape);
  out_min_shape = calc_shape(num_min_vec, x_min_shape);
  out_max_shape = calc_shape(num_max_vec, x_max_shape);
  return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
  auto output_shape_rank = SizeToLong(output_shape.size());
  auto out_max_shape_rank = SizeToLong(calc_shape(num_max_vec, x_max_shape).size());
  bool x_min_any_shape =
    std::any_of(x_min_shape.begin(), x_min_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
  bool x_max_any_shape =
    std::any_of(x_max_shape.begin(), x_max_shape.end(), [](int64_t dim) { return dim == abstract::Shape::SHP_ANY; });
  if ((out_max_shape_rank == output_shape_rank) && (!x_min_any_shape && !x_max_any_shape)) {
    out_min_shape = calc_shape(num_min_vec, x_min_shape);
    out_max_shape = calc_shape(num_max_vec, x_max_shape);
    return std::make_shared<abstract::Shape>(output_shape, out_min_shape, out_max_shape);
  }
  return std::make_shared<abstract::Shape>(output_shape);
}
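The calc_shape helper above boils down to one rule: the leading dimensions covered by segment_ids collapse into a single num_segments dimension, and the remaining dimensions of input_x are appended. A hedged Python sketch (the function name calc_output_shape is made up for illustration):

def calc_output_shape(num_segments, x_shape, ids_rank):
    # output shape = (num_segments, x_{N+1}, ..., x_R)
    return [num_segments] + list(x_shape[ids_rank:])

print(calc_output_shape(4, (6, 3, 5), 1))  # [4, 3, 5]
print(calc_output_shape(4, (6,), 1))       # [4]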
TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
@@ -126,17 +134,17 @@ TypePtr UnsortedSegmentSumInferType(const PrimitivePtr &primitive, const std::ve
  auto prim_name = primitive->name();
  /* check segment_ids */
  auto ids_ptr = input_args[kInputIndex1]->BuildType();
  std::set<TypePtr> ids_type_set = {kInt32, kInt64};
  std::set<TypePtr> ids_type_set = {kInt16, kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids type", ids_ptr, ids_type_set, prim_name);
  /* check num_segments */
  auto num_ptr = input_args[kInputIndex2]->BuildType();
  std::map<std::string, TypePtr> args_num_segments;
  (void)args_num_segments.insert({"num_segments", num_ptr});
  const std::set<TypePtr> num_type_set = {kInt32, kInt64};
  const std::set<TypePtr> num_type_set = {kInt16, kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_num_segments, num_type_set, prim_name);
  /* check input_x */
  auto x_type_ptr = input_args[kInputIndex0]->BuildType();
  std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kInt32};
  std::set<TypePtr> x_type_set = {kFloat16, kFloat32, kFloat64, kInt32, kInt64};
  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type_ptr, x_type_set, prim_name);
}
} // namespace
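A plain-Python summary of the dtype sets widened in the hunk above; the sets and the helper check_dtypes are illustrative only, not the CheckAndConvertUtils utilities themselves:

SEGMENT_IDS_TYPES = {"int16", "int32", "int64"}
NUM_SEGMENTS_TYPES = {"int16", "int32", "int64"}
INPUT_X_TYPES = {"float16", "float32", "float64", "int32", "int64"}

def check_dtypes(x_dtype, ids_dtype, num_dtype):
    # Mirrors the widened type checks after this change.
    if ids_dtype not in SEGMENT_IDS_TYPES:
        raise TypeError(f"segment_ids dtype {ids_dtype} is not supported")
    if num_dtype not in NUM_SEGMENTS_TYPES:
        raise TypeError(f"num_segments dtype {num_dtype} is not supported")
    if x_dtype not in INPUT_X_TYPES:
        raise TypeError(f"input_x dtype {x_dtype} is not supported")

check_dtypes("float32", "int32", "int64")  # passes silently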
@@ -23,8 +23,6 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
    .compute_cost(10) \
    .kernel_name("unsorted_segment_sum") \
    .partial_flag(True) \
    .need_check_supported(True) \
    .dynamic_compile_static(True) \
    .dynamic_shape(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "segment_ids", False, "required", "all", "optional") \
@@ -3967,9 +3967,9 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
    is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.

    Args:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
        - **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
        - **num_segments** (int) - Set :math:`z` as num_segments.
        input_x (Tensor): The shape is :math:`(x_1, x_2, ..., x_R)`.
        segment_ids (Tensor): Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
        num_segments (int): Set :math:`z` as num_segments.

    Returns:
        Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
@@ -3987,13 +3987,13 @@ def unsorted_segment_sum(input_x, segment_ids, num_segments):
        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
        >>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
        >>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
        >>> print(output)
        [3. 3. 4. 0.]
        >>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
        >>> num_segments = 6
        >>> output = ops.UnsortedSegmentSum(input_x, segment_ids, num_segments)
        >>> output = ops.unsorted_segment_sum(input_x, segment_ids, num_segments)
        >>> print(output)
        [3. 3. 4. 2. 5. 0.]
    """