From 84a536d6b50709e5c5441d62b7eec4dac3bb4712 Mon Sep 17 00:00:00 2001 From: lilinjie Date: Mon, 5 Sep 2022 00:47:08 +0800 Subject: [PATCH] fix some issues --- .../ops/mindspore.ops.StridedSlice.rst | 20 +++++++++++-------- mindspore/core/ops/sign.cc | 4 +++- .../mindspore/ops/operations/array_ops.py | 17 ++++++++++------ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/docs/api/api_python/ops/mindspore.ops.StridedSlice.rst b/docs/api/api_python/ops/mindspore.ops.StridedSlice.rst index fdf5beabdae..74ab2748562 100644 --- a/docs/api/api_python/ops/mindspore.ops.StridedSlice.rst +++ b/docs/api/api_python/ops/mindspore.ops.StridedSlice.rst @@ -3,15 +3,19 @@ .. py:class:: mindspore.ops.StridedSlice(begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0) - 输入Tensor根据步长和索引进行切片提取。 + 对输入Tensor根据步长和索引进行切片提取。 - 给定一个输入Tensor,此操作会插入长度为1的维度。从给定的 `input_tensor` 中提取大小为 `(end-begin)/stride` 的片段。从起始位置开始,根据步长和索引进行提取,直到所有维度都不小于结束位置为止。 + 该算子在给定的 `input_tensor` 中提取大小为 `(end-begin)/stride` 的片段。从起始位置开始,根据步长和索引进行提取,直到所有维度的都不小于结束位置为止。 给定一个 `input_x[m1, m2, ...、mn]` 。 `begin` 、 `end` 和 `strides` 是长度为n的向量。 - 在每个掩码字段中(`begin_mask`、`end_mask`、`ellipsis_mask`、`new_axis_mask`、`shrink_ax_mask`),第i位将对应于第i个m。 + 在每个掩码字段中(`begin_mask`、`end_mask`、`ellipsis_mask`、`new_axis_mask`、`shrink_axis_mask`),第i位将对应于第i个m。 - 如果设置了 `begin_mask` 的第i位,则忽略 `begin[i]` ,而使用该维度中最有可能的取值范围。除了结尾的取值范围, `end_mask` 是类似的。 + 对每个特定的mask,内部先将各mask转化为二进制表示, 然后倒序排布后进行计算。比如说对于一个5*6*7的Tensor,mask设置为3, 3转化为二进制表示为ob011, 倒序后为ob110, + + 则该mask只在第0维和第1维产生作用, 下面各自举例说明。 + + 如果设置了 `begin_mask` 的第i位,则忽略 `begin[i]`,而使用该维度的最大可能取值范围,`end_mask`实现方式与之类似。 对于5*6*7的Tensor, `x[2:,:3,:]` 等同于 `x[2:5,0:3,0:7]` 。 @@ -19,13 +23,13 @@ 对于5*6*7*8的Tensor, `x[2:,...,:6]` 等同于 `x[2:5,:,:,0:6]` 。 `x[2:,...]` 等同于 `x[2:5,:,:,:]` 。 - 如果设置了 `new_ax_mask` 的第i位,则忽略 `begin` 、 `end` 和 `strides` ,并在输出Tensor的指定位置添加新的长度为1的维度。 + 如果设置了 `new_axis_mask` 的第i位,则忽略 `begin` 、 `end` 和 `strides` ,并在输出Tensor的指定位置添加新的长度为1的维度。 - 对于5*6*7的Tensor, `x[:2, newaxis, :6]` 将产生一个shape为 :math:`(2, 1, 7)` 的Tensor。 + 对于5*6*7的Tensor, `x[:2, newaxis, :6]` 将产生一个shape为 :math:`(2, 1, 6, 7)` 的Tensor。 - 如果设置了 `shrink_ax_mask` 的第i位,则第i个大小将维度收缩1,并忽略 `begin[i]` 、 `end[i]` 和 `strides[i]` 索引处的值。 + 如果设置了 `shrink_axis_mask` 的第i位,则第i维被收缩掉,并忽略 `begin[i]` 、 `end[i]` 和 `strides[i]` 索引处的值。 - 对于5*6*7的Tensor, `x[:, 5, :]` 将使得 `shrink_axis_mask` 等于4。 + 对于5*6*7的Tensor, `x[:, 5, :]` 相当于将 `shrink_axis_mask` 设置为2, 使得输出shape为:math:`(5, 7)` 。 .. note:: 步长可能为负值,这会导致反向切片。 `begin` 、 `end` 和 `strides` 的shape必须相同。 `begin` 和 `end` 是零索引。 `strides` 的元素必须非零。 diff --git a/mindspore/core/ops/sign.cc b/mindspore/core/ops/sign.cc index f7dd2ee1758..e6a28b4370e 100644 --- a/mindspore/core/ops/sign.cc +++ b/mindspore/core/ops/sign.cc @@ -54,7 +54,9 @@ AbstractBasePtr SignInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt const int64_t input_num = 1; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); - return abstract::MakeAbstract(SignInferShape(primitive, input_args), SignInferType(primitive, input_args)); + auto infer_type = SignInferType(primitive, input_args); + auto infer_shape = SignInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(Sign, prim::kPrimSign, SignInfer, nullptr, true); } // namespace ops diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 2a8dd12107d..6024a3e4d31 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -3279,7 +3279,6 @@ class StridedSlice(PrimitiveWithInfer): Extracts a strided slice of a tensor. - Given an input tensor, this operation inserts a dimension of length 1 at the dimension. This operation extracts a fragment of size (end-begin)/stride from the given 'input_tensor'. Starting from the beginning position, the fragment continues adding stride to the index until all dimensions are not less than the ending position. @@ -3289,26 +3288,32 @@ class StridedSlice(PrimitiveWithInfer): In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask`, `shrink_axis_mask`) the ith bit will correspond to the ith m. + For each mask, it will be converted to a binary representation internally, and then + reverse the result to start the calculation. For a 5*6*7 tensor with a given mask value of 3 which + can be represented as ob011. Reverse that we get ob110, which implies the first and second dim of the + original tensor will be effected by this mask. See examples below: + If the ith bit of `begin_mask` is set, `begin[i]` is ignored and the fullest possible range in that dimension is used instead. `end_mask` is analogous, except with the end range. - As for a 5*6*7 tensor, `x[2:,:3,:]` is equivalent to `x[2:5,0:3,0:7]`. + For a 5*6*7 tensor, `x[2:,:3,:]` is equivalent to `x[2:5,0:3,0:7]`. If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions as needed will be inserted between other dimensions. Only one non-zero bit is allowed in `ellipsis_mask`. - As for a 5*6*7*8 tensor, `x[2:,...,:6]` is equivalent to `x[2:5,:,:,0:6]`. + For a 5*6*7*8 tensor, `x[2:,...,:6]` is equivalent to `x[2:5,:,:,0:6]`. `x[2:,...]` is equivalent to `x[2:5,:,:,:]`. If the ith bit of `new_axis_mask` is set, `begin`, `end` and `strides` are ignored and a new length 1 dimension is added at the specified position in the output tensor. - As for a 5*6*7 tensor, `x[:2, newaxis, :6]` will produce a tensor with shape :math:`(2, 1, 7)` . + For a 5*6*7 tensor, `x[:2, newaxis, :6]` will produce a tensor with shape :math:`(2, 1, 6, 7)` . - If the ith bit of `shrink_axis_mask` is set, ith size shrinks the dimension by 1, taking on the value + If the ith bit of `shrink_axis_mask` is set, dimension i will be shrunk to 0, taking on the value at index `begin[i]`, `end[i]` and `strides[i]` are ignored. - As for a 5*6*7 tensor, `x[:, 5, :]` will result in `shrink_axis_mask` equal to 4. + For a 5*6*7 tensor, `x[:, 5, :]` is equivalent to setting the `shrink_axis_mask` to 2 which results in + an out shape of :math:`(5, 7)`. Note: The stride may be negative value, which causes reverse slicing.