!45477 expsoe Primitve APIs part2

Merge pull request !45477 from 李林杰/code_docs_expose_Primitive_APIs_part2
This commit is contained in:
i-robot 2022-11-18 12:00:28 +00:00 committed by Gitee
commit 1a14ff5c87
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
22 changed files with 459 additions and 82 deletions

View File

@ -214,7 +214,11 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
mindspore.ops.L2Normalize
mindspore.ops.NMSWithMask
mindspore.ops.NonMaxSuppressionWithOverlaps
mindspore.ops.PSROIPooling
mindspore.ops.RGBToHSV
mindspore.ops.ResizeArea
mindspore.ops.ResizeBicubic
mindspore.ops.ResizeBilinearV2
mindspore.ops.ROIAlign
mindspore.ops.SampleDistortedBoundingBoxV2
mindspore.ops.ScaleAndTranslate
@ -397,6 +401,7 @@ Reduction算子
mindspore.ops.Orgqr
mindspore.ops.Svd
mindspore.ops.TridiagonalMatMul
mindspore.ops.Qr
Tensor操作算子
----------------
@ -431,8 +436,11 @@ Tensor创建
mindspore.ops.LogNormalReverse
mindspore.ops.Multinomial
mindspore.ops.NonDeterministicInts
mindspore.ops.ParameterizedTruncatedNormal
mindspore.ops.RandomCategorical
mindspore.ops.RandomChoiceWithMask
mindspore.ops.RandomGamma
mindspore.ops.RandomPoisson
mindspore.ops.Randperm
mindspore.ops.StandardLaplace
mindspore.ops.StandardNormal
@ -509,8 +517,10 @@ Array操作
mindspore.ops.Nonzero
mindspore.ops.ParallelConcat
mindspore.ops.PopulationCount
mindspore.ops.RaggedRange
mindspore.ops.Range
mindspore.ops.Rank
mindspore.ops.Renorm
mindspore.ops.Reshape
mindspore.ops.ResizeNearestNeighborV2
mindspore.ops.ReverseSequence

View File

@ -0,0 +1,32 @@
mindspore.ops.PadAndShift
==========================
.. py:class:: mindspore.ops.PadAndShift
使用-1初始化一个Tensor然后从 `input_x` 转移一个切片到该Tensor。
.. note::
如果在Python中使用PadAndShift按下面流程得到输出Tensor
output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)]
输入:
- **input_x** (Tensor) - 输入Tensor将被转移到 `output`
- **cum_sum_arr** (Tensor) - `cum_sum_arr` 的最后一个值是输出Tensor的长度 `cum_sum_arr[shift_idx]` 是转移起点, `cum_sum_arr[shift_idx+1]` 是转移终点。
- **shift_idx** (int) - `cum_sum_arr` 的下标。
输出:
- **output** (Tensor) - Tensor数据类型与 `input` 一致。
异常:
- **TypeError** - `input_x` 或者 `cum_sum_arr` 不是Tensor。
- **TypeError** - `shift_idx` 不是int。
- **ValueError** - `shift_idx` 的值大于等于 `cum_sum_arr` 的长度。

View File

@ -0,0 +1,34 @@
mindspore.ops.ParameterizedTruncatedNormal
===========================================
.. py:class:: mindspore.ops.ParameterizedTruncatedNormal(seed=0, seed2=0)
返回一个具有指定shape的Tensor其数值取自截断正态分布。
当其shape为 :math:`(batch_size, *)` 的时候, `mean``stdevs``min``max` 的shape应该为 :math:`()` 或者 :math:`(batch_size, )`
.. note::
在广播之后,在任何位置, `min` 的值必须严格小于 `max` 的值。
参数:
- **seed** (int可选) - 随机数种子。如果 `seed` 或者 `seed2` 被设置为非零则使用这个非零值。否则使用一个随机生成的种子。默认值0。
- **seed2** (int可选) - 另一个随机种子避免发生冲突。默认值0。
输入:
- **shape** (Tensor) - 生成Tensor的shape。数据类型必须是int32或者int64。
- **mean** (Tensor) - 截断正态分布均值。数据类型必须是float16、float32或者float64。
- **stdevs** (Tensor) - 截断正态分布的标准差。其值必须大于零,数据类型与 `mean` 一致。
- **min** (Tensor) - 最小截断值,数据类型与 `mean` 一致。
- **max** (Tensor) - 最大截断值,数据类型与 `mean` 一致。
输出:
Tensor其shape由 `shape` 决定,数据类型与 `mean` 一致。
异常:
- **TypeError** - `shape``mean``stdevs``min``max` 数据类型不支持。
- **TypeError** - `mean``stdevs``min``max` 的shape不一致。
- **TypeError** - `shape``mean``stdevs``min``max` 不全是Tensor。
- **ValueError** - 当其 `shape`:math:`(batch_size, *)` 时, `mean``stdevs``min` 或者 `max` 的shape不是 :math:`()` 或者 :math:`(batch_size, )`
- **ValueError** - `shape` 的元素不全大于零。
- **ValueError** - `stdevs` 的值不全大于零。
- **ValueError** - `shape` 的的元素个数小于2。
- **ValueError** - `shape` 不是一维Tensor。

View File

@ -0,0 +1,26 @@
mindspore.ops.PSROIPooling
==========================
.. py:class:: mindspore.ops.PSROIPooling(spatial_scale, group_size, output_dim)
对输入Tensor应用Position Sensitive ROI-Pooling。
参数:
- **spatial_scale** (float) - 将框坐标映射到输入坐标的比例因子。例如如果你的框定义在224x224的图像上并且你的输入是112x112的特征图由原始图像的0.5倍缩放产生此时需要将其设置为0.5。
- **group_size** (int) - 执行池化后输出的大小(以像素为单位),以(高度,宽度)的格式输出。
- **output_dim** (int) -执行池化后输出的维度。
输入:
- **features** (Tensor) - 输入特征Tensor其shape必须为 :math:`(N, C, H, W)` 。 各维度的值应满足: :math:`(C == output_dim * group_size * group_size)` 。数据类型为float16或者float32。
- **rois** (Tensor) - 其shape为 :math:`(batch, 5, rois_n)` 数据类型为float16或者float32。第一个维度的batch为批处理大小。第二个维度的大小必须为5。第三维度rois_n是rois的数量。rois_n的值格式为(index, x1, y1, x2, y2)。其中第一个元素是rois的索引。方框坐标格式为(x1、y1、x2、y2)之后将把这些方框的选中的区域提取出来。区域坐标必须满足0 <= x1 < x2和0 <= y1 < y2。
输出:
- **out** (Tensor) - 池化后的输出。其shape为 :math:`(rois.shape[0] * rois.shape[2], output\_dim, group\_size, group\_size)`
异常:
- **TypeError** - `spatial_scale` 不是float类型。
- **TypeError** - `group_size` 或者 `output_dim` 不是 int类型。
- **TypeError** - `features` 或者 `rois` 不是Tensor。
- **TypeError** - `rois` 数据类型不是float16或者float32。
- **ValueError** - `features` 的shape不满足 :math:`(C == output_dim * group_size * group_size)`
- **ValueError** - `spatial_scale` 为负数。

View File

@ -0,0 +1,23 @@
mindspore.ops.Qr
=================
.. py:class:: mindspore.ops.Qr(full_matrices=False)
返回一个或多个矩阵的QR正交三角分解。如果 `full_matrices` 设为True则计算全尺寸q和r如果为False默认值则计算q的P列其中P是 `x` 的2个最内层维度中的最小值。
参数:
- **full_matrices** (bool可选) - 是否进行全尺寸的QR分解。默认值False。
输入:
- **x** (Tensor) - 要进行分解的矩阵。矩阵必须至少为二维。数据类型float16、float32、float64、complex64、complex128。
`x` 的shape定义为 :math:`(..., m, n)` p定义为m和n的最小值。
输出:
- **q** (Tensor) - `x` 的正交矩阵。如果 `full_matrices` 为True则shape为 :math:`(m, m)` 否则shape为 :math:`(m, p)``q` 的数据类型与 `x` 相同。
- **r** (Tensor) - `x` 的上三角形矩阵。如果 `full_matrices` 为True则shape为 :math:`(m, n)` 否则shape为 :math:`(p, n)``r` 的数据类型与 `x` 相同。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `full_matrices` 不是bool。
- **TypeError** - `x` 的维度小于2。

View File

@ -0,0 +1,29 @@
mindspore.ops.RaggedRange
==========================
.. py:class:: mindspore.ops.RaggedRange(Tsplits)
返回包含指定数数列的RaggedTensor。
参数:
- **Tsplits** (mindspore.dtype) - 输出的类型。它的值必须是mstype.int32或者mstype.int64。
输入:
- **starts** (Tensor) - 每个数列的开始。是一个 0D或1D Tensor数据类型为int32、int64、float32或float64。
- **limits** (Tensor) - 每个数列的上限shape与数据类型与 `starts` 一致。
- **deltas** (Tensor) - 每个数列增量shape与数据类型与 `starts` 一致其中所有元素的值不能为0。
输出:
- **rt_nested_splits** (Tensor) - 返回RagdTensor的嵌套拆分Tensor数据类型类型为 `Tsplits` 。shape等于输入 `starts` 的shape加1。
- **rt_dense_values** (Tensor) - 返回RagdTensor的密集值Tensor其数据类型与输入 `starts` 相同。设输入 `starts、` `limits``delta` 的大小为i。
- 如果 `starts``limits``delta` 的数据类型为int32或int64则输出 `rt_dense_values` 的shape等于 :math:`sum(abs(limits[i] - starts[i]) + abs(deltas[i]) - 1) / abs(deltas[i]))`
- 如果 `starts``limits``delta` 的数据类型为float32或者float64则输出 `rt_dense_values` 的shape等于 :math:`sum(ceil(abs((limits[i] - starts[i]) / deltas[i]))`
异常:
- **TypeError** - 如任意一个输入不是Tensor。
- **TypeError** - 如果 `starts` 的数据类型不是int32、int64、float32或float64。
- **TypeError** - 如果 `starts``limits``deltas` 的数据类型不一致。
- **TypeError** - 如果 `Tsplits` 不是mstype.int32或者mstype.int64。
- **ValueError** - 如果 `starts``limits``deltas` 不是 0D或1D Tensor。
- **ValueError** - 如果 `deltas` 等于0。
- **ValueError** - 如果 `starts``limits``deltas` 的shape不一致。

View File

@ -0,0 +1,35 @@
mindspore.ops.RandomGamma
==========================
.. py:class:: mindspore.ops.RandomGamma(seed=0, seed2=0)
根据概率密度函数分布生成随机正值浮点数x
.. math::
\text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}}
.. note::
- 随机种子:通过一些复杂的数学算法,可以得到一组有规律的随机数,而随机种子就是这个随机数的初始值。随机种子相同,得到的随机数就不会改变。
- 全局的随机种子和算子层的随机种子都没设置:使用默认值当做随机种子。
- 全局的随机种子设置了,算子层的随机种子未设置:随机生成一个种子和全局的随机种子拼接。
- 全局的随机种子未设置,算子层的随机种子设置了:使用默认的全局的随机种子,和算子层的随机种子拼接。
- 全局的随机种子和算子层的随机种子都设置了:全局的随机种子和算子层的随机种子拼接。
参数:
- **seed** (int可选) - 算子层的随机种子用于生成随机数。必须是非负的。默认值0。
- **seed2** (int可选) - 全局的随机种子和算子层的随机种子共同决定最终生成的随机数。必须是非负的。默认值0。
输入:
- **shape** (tuple) - 待生成的随机Tensor的shape。只支持常量值。
- **alpha** (Tensor) - α为Gamma分布的shape parameter主要决定了曲线的形状。其值必须大于0。数据类型为float32。
- **beta** (Tensor) - β为Gamma分布的inverse scale parameter主要决定了曲线有多陡。其值必须大于0。数据类型为float32。
输出:
Tensor。shape是输入 `shape` `alpha` `beta` 广播后的shape。数据类型为float32。
异常:
- **TypeError** - `seed``seed2` 的数据类型不是int。
- **TypeError** - `alpha``beta` 不是Tensor。
- **TypeError** - `alpha``beta` 的数据类型不是float32。
- **ValueError** - `shape` 不是常量值。

View File

@ -0,0 +1,27 @@
mindspore.ops.RandomPoisson
============================
.. py:class:: mindspore.ops.RandomPoisson(seed=0, seed2=0, dtype=mindspore.int64)
根据离散概率密度函数分布生成随机非负数浮点数i
.. math::
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
参数:
- **seed** (int可选) - 随机数种子。如果 `seed` 或者 `seed2` 被设置为非零则使用这个非零值。否则使用一个随机生成的种子。默认值0。
- **seed2** (int可选) - 另一个随机种子避免发生冲突。默认值0。
- **dtype** (mindspore.dtype可选) - 输出数据类型, 默人值mindspore.int64。
输入:
- **shape** (tuple) - 待生成的随机Tensor的shape是一个一维Tensor。数据类型为nt32或int64。
- **rate** (Tensor) - `rate` 为Poisson分布的μ参数决定数字的平均出现次数。数据类型是其中之一[float16, float32, float64, int32, int64]。
输出:
Tensor。shape是 :math:`(*shape, *rate.shape)` ,数据类型由参数 `dtype` 指定。
异常:
- **TypeError** - `shape` 不是Tensor或数据类型不是int32或int64。
- **TypeError** - `dtype` 数据类型不是int32或int64。
- **TypeError** - `shape` 不是一维Tensor。
- **ValueError** - `shape` 的元素存在负数。

View File

@ -0,0 +1,22 @@
mindspore.ops.Receive
======================
.. py:class:: mindspore.ops.Receive(sr_tag, src_rank, shape, dtype, group="hccl_world_group/nccl_world_group")
从src_rank接收张量。
.. note::
Send和Receive必须组合使用并且具有相同的sr_tag。Receive必须在服务器之间使用。
参数:
- **sr_tag** (int) - 标识发送/接收消息标签的所需的整数。消息将将由具有相同 `sr_tag` 的Send算子发送。
- **src_rank** (int) - 标识设备rank的所需整数。
- **shape** (list[int]) - 标识要接收的Tensor的shape的所需列表。
- **dtype** (Type) - 标识要接收的Tensor类型的必要类型。支持的类型int8、int16、int32、float16和float32。
- **group** (str可选) - 工作通信组。默认值“hccl_world_group/nccl_world_group”。
输入:
- **input_x** (Tensor) - 输入Tensor其shape为 :math:`(x_1, x_2, ..., x_R)`

View File

@ -0,0 +1,9 @@
mindspore.ops.Renorm
=====================
.. py:class:: mindspore.ops.Renorm(p, dim, maxnorm)
沿维度 `dim` 重新规范输入 `input_x` 的子Tensor并且每个子Tensor的p范数不超过给定的最大范数 `maxnorm` 。如果子Tensor的p范数小于 `maxnorm` 则当前子Tensor不需要修改否则该子Tensor需要修改为对应位置的原值除以该子Tensor的p范数然后再乘上 `maxnorm`
更多参考详见 :func:`mindspore.ops.renorm`

View File

@ -0,0 +1,31 @@
mindspore.ops.ResizeArea
=========================
.. py:class:: mindspore.ops.ResizeArea(align_corners=False)
使用面积插值调整图像大小到指定的大小。
调整过程只改变输入图像的高和宽维度数据。
.. warning::
`size` 的值必须大于0。
参数:
- **align_corners** (bool可选) - 如果为True则输入输出图像四个角像素的中心被对齐同时保留角像素处的值。默认值False。
输入:
- **images** (Tensor) -输入图像为四维的Tensor其shape为 :math:`(batch, channels, height, width)` 支持的数据类型有int8、int16、int32、int64、float16、float32、float64、uint8和uint16。
- **size** (Tensor) - 必须为含有两个元素的一维的Tensor分别为new_height, new_width表示输出图像的高和宽。支持的数据类型为int32。
输出:
Tensor调整大小后的图像。shape为 :math:`(batch, new\_height, new\_width, channels)` 的四维Tensor数据类型为float32。
异常:
- **TypeError** - `images` 的数据类型不支持。
- **TypeError** - `size` 不是int32。
- **TypeError** - `align_corners` 不是bool。
- **ValueError** - 输入个数不是2。
- **ValueError** - `images` 的维度不是4。
- **ValueError** - `size` 的维度不是1。
- **ValueError** - `size` 含有元素个数2。
- **ValueError** - `size` 的元素不全是正数。

View File

@ -0,0 +1,31 @@
mindspore.ops.ResizeBicubic
============================
.. py:class:: mindspore.ops.ResizeBicubic(align_corners=False, half_pixel_centers=False)
使用双三次插值调整图像大小到指定的大小。
.. warning::
输出最大长度为1000000。
参数:
- **align_corners** (bool可选) - 如果为True则输入输出图像四个角像素的中心被对齐同时保留角像素处的值。默认值False。
- **half_pixel_centers** (bool可选) - 是否使用半像素中心对齐。如果设置为True那么 `align_corners` 应该设置为False。默认值False。
输入:
- **images** (Tensor) -输入图像为四维的Tensor其shape为 :math:`(batch, height, width, channels)` 支持的数据类型有int8、int16、int32、int64、float16、float32、float64、uint8和uint16。
- **size** (Tensor) - 必须为含有两个元素的一维的Tensor分别为new_height, new_width表示输出图像的高和宽。支持的数据类型为int32。
输出:
Tensor调整大小后的图像。shape为 :math:`(batch, new\_height, new\_width, channels)` 的四维Tensor数据类型为float32。
异常:
- **TypeError** - `images` 的数据类型不支持。
- **TypeError** - `size` 不是int32。
- **TypeError** - `align_corners` 不是bool。
- **TypeError** - `half_pixel_centers` 不是bool。
- **ValueError** - `images` 的维度不是4。
- **ValueError** - `size` 的维度不是1。
- **ValueError** - `size` 含有元素个数数不是2。
- **ValueError** - `size` 的元素不全是正数。
- **ValueError** - `align_corners``half_pixel_centers` 同时为True。

View File

@ -0,0 +1,31 @@
mindspore.ops.ResizeBilinearV2
===============================
.. py:class:: mindspore.ops.ResizeBicubic(align_corners=False, half_pixel_centers=False)
使用双线性插值调整图像大小到指定的大小。
调整过程只改变输入图像最低量维度的数据,分别代表高和宽。
.. warning::
在CPU后端不支持将 `half_pixel_centers` 设为True。
参数:
- **align_corners** (bool可选) - 如果为True则使用比例 :math:`(new\_height - 1) / (height - 1)` 对输入进行缩放此时输入图像和输出图像的四个角严格对齐。如果为False使用比例 :math:`new\_height / height` 输入进行缩放。默认值False。
- **half_pixel_centers** (bool可选) - 是否使用半像素中心对齐。如果设置为True那么 `align_corners` 应该设置为False。默认值False。
输入:
- **x** (Tensor) -输入图像为四维的Tensor其shape为 :math:`(batch, channels, height, width)` 支持的数据类型有float16、float32。
- **size** (Union[tuple[int], list[int], Tensor]) - 调整后图像的尺寸。为含有两个元素的一维的Tensor或者list或者tuple分别为 :math:`(new\_height, new\_width)`
输出:
Tensor调整大小后的图像。shape为 :math:`(batch, channels, new\_height, new\_width)` 的四维Tensor数据类型与 `x` 一致。
异常:
- **TypeError** - `align_corners` 不是bool。
- **TypeError** - `half_pixel_centers` 不是bool。
- **TypeError** - `align_corners``half_pixel_centers` 同时为True。
- **ValueError** - `half_pixel_centers` 为True同时运行平台为CPU。
- **ValueError** - `x` 维度不是4。
- **ValueError** - `size` 为Tensor且维度不是1。
- **ValueError** - `size` 含有元素个数不是2。

View File

@ -18,4 +18,4 @@ mindspore.ops.gamma
- **TypeError** - `shape` 不是tuple。
- **TypeError** - `alpha``beta` 不是Tensor。
- **TypeError** - `seed` 的数据类型不是int。
- **TypeError** - `alpha``beta` 的数据类型不是float32。
- **TypeError** - `alpha``beta` 的数据类型不是float32。

View File

@ -3,19 +3,19 @@ mindspore.ops.random_gamma
.. py:function:: mindspore.ops.random_gamma(shape, alpha, seed=0, seed2=0)
根据伽马分布产生成随机数。
根据伽马分布产生成随机数。
参数:
- **shape** (Tensor) - 指定生成随机数的shape。任意维度的Tensor。
- **alpha** (Tensor) - :math:`\alpha` 分布的参数。应该大于0且数据类型为half、float32或者float64。
- **seed** (int) - 随机数生成器的种子必须是非负数默认为0。
- **seed2** (int) - 随机数生成器的种子必须是非负数默认为0。
参数:
- **shape** (Tensor) - 指定生成随机数的shape。任意维度的Tensor。
- **alpha** (Tensor) - :math:`\alpha` 分布的参数。应该大于0且数据类型为half、float32或者float64。
- **seed** (int) - 随机数生成器的种子必须是非负数默认为0。
- **seed2** (int) - 随机数生成器的种子必须是非负数默认为0。
返回:
Tensor。shape是输入 `shape``alpha` 拼接后的shape。数据类型和alpha一致。
返回:
Tensor。shape是输入 `shape``alpha` 拼接后的shape。数据类型和alpha一致。
异常:
- **TypeError** `shape` 不是Tensor。
- **TypeError** `alpha` 不是Tensor。
- **TypeError** `seed` 的数据类型不是int。
- **TypeError** `alpha` 的数据类型不是half、float32或者float64。
异常:
- **TypeError** `shape` 不是Tensor。
- **TypeError** `alpha` 不是Tensor。
- **TypeError** `seed` 的数据类型不是int。
- **TypeError** `alpha` 的数据类型不是half、float32或者float64。

View File

@ -213,7 +213,11 @@ Image Processing
mindspore.ops.L2Normalize
mindspore.ops.NMSWithMask
mindspore.ops.NonMaxSuppressionWithOverlaps
mindspore.ops.PSROIPooling
mindspore.ops.RGBToHSV
mindspore.ops.ResizeArea
mindspore.ops.ResizeBicubic
mindspore.ops.ResizeBilinearV2
mindspore.ops.ROIAlign
mindspore.ops.SampleDistortedBoundingBoxV2
mindspore.ops.ScaleAndTranslate
@ -396,6 +400,7 @@ Linear Algebraic Operator
mindspore.ops.Orgqr
mindspore.ops.Svd
mindspore.ops.TridiagonalMatMul
mindspore.ops.Qr
Tensor Operation Operator
--------------------------
@ -430,8 +435,11 @@ Random Generation Operator
mindspore.ops.LogNormalReverse
mindspore.ops.Multinomial
mindspore.ops.NonDeterministicInts
mindspore.ops.ParameterizedTruncatedNormal
mindspore.ops.RandomCategorical
mindspore.ops.RandomChoiceWithMask
mindspore.ops.RandomGamma
mindspore.ops.RandomPoisson
mindspore.ops.Randperm
mindspore.ops.StandardLaplace
mindspore.ops.StandardNormal
@ -508,8 +516,10 @@ Array Operation
mindspore.ops.Nonzero
mindspore.ops.ParallelConcat
mindspore.ops.PopulationCount
mindspore.ops.RaggedRange
mindspore.ops.Range
mindspore.ops.Rank
mindspore.ops.Renorm
mindspore.ops.Reshape
mindspore.ops.ResizeNearestNeighborV2
mindspore.ops.ReverseSequence

View File

@ -258,22 +258,30 @@ class DynamicAssign(PrimitiveWithCheck):
class PadAndShift(PrimitiveWithCheck):
"""
Pad a tensor with -1, and shift with a length.
Initialize a tensor with -1, and copy a slice from `input_x` to the padded Tensor.
Note:
If use python, PadAndShift is:
output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)]
Inputs:
- **input_x** (Tensor) - The input Tensor, which will be copied
to `output`.
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
the pad length of output tensor, cum_sum_arr[shift_idx] is
the start to shift, and cum_sum_arr[shift_idx+1] is the end.
- **shift_idx** (Int) - The idx of cum_sum_arr.
if use python, PadAndShift is:
output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)]
the pad length of output tensor, `cum_sum_arr[shift_idx]` is
the start to shift, and `cum_sum_arr[shift_idx+1]` is the end.
- **shift_idx** (int) - The idx of `cum_sum_arr` .
Outputs:
Tensor, has the same type as original `variable`.
- **output** (Tensor) - Tensor, has the same type as `input`.
Raises:
TypeError: `input_x` or `cum_sum_arr` is not Tensor.
TypeError: `shift_idx` is not int.
ValueError: Value of `shift_idx` is larger than or equal to the length of `cum_sum_arr` .
Supported Platforms:
`CPU`

View File

@ -460,7 +460,7 @@ class Send(PrimitiveWithInfer):
class Receive(PrimitiveWithInfer):
"""
receive tensors from src_rank.
Receive tensors from src_rank.
Note:
Send and Receive must be used in combination and have same sr_tag.
@ -473,7 +473,7 @@ class Receive(PrimitiveWithInfer):
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
group (str, optional): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

View File

@ -686,6 +686,9 @@ class ResizeBilinearV2(Primitive):
The resizing only affects the lower two dimensions which represent the height and width.
.. warning::
On CPU, setting `half_pixel_centers` to True is currently not supported.
Args:
align_corners (bool, optional): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`,
which exactly aligns the 4 corners of images and resized images. If false,
@ -708,6 +711,9 @@ class ResizeBilinearV2(Primitive):
TypeError: If `half_pixel_centers` is not a bool.
TypeError: If `align_corners` and `half_pixel_centers` are all True.
ValueError: If `half_pixel_centers` is True and device_target is CPU.
ValueError: If dim of `x` is not 4.
ValueError: If `size` is Tensor and its dim is not 1.
ValueError: If `size` contains other than 2 elements.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
@ -739,37 +745,39 @@ class ResizeBilinearV2(Primitive):
class ResizeBicubic(Primitive):
"""
r"""
Resize images to size using bicubic interpolation.
.. warning::
The max output length is 1000000.
Args:
align_corners (bool):If true, the centers of the 4 corner pixels of the input
align_corners (bool, optional):If true, the centers of the 4 corner pixels of the input
and output tensors are aligned, preserving the values at the corner pixels.Default: False.
half_pixel_centers (bool): An optional bool. Default: False.
half_pixel_centers (bool): Whether to use half-pixel center alignment. If set to True,
`align_corners` should be False. Default: False.
Inputs:
- **images** (Tensor) - The input image must be a 4-D tensor of shape [batch, height, width, channels].
- **images** (Tensor) - The input image must be a 4-D tensor of shape :math:`(batch, height, width, channels)`.
The format must be NHWC.
Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
- **size** (Tensor) - A 1-D tensor of shape [2], with 2 elements: new_height, new_width.
Types allowed: int32.
Outputs:
A 4-D tensor of shape [batch, new_height, new_width, channels] with type: float32.
A 4-D tensor of shape :math:`(batch, new\_height, new\_width, channels)` with type float32.
Raises:
TypeError: If `images` type is not allowed.
TypeError: If `size` type is not allowed.
TypeError: If `align_corners` type is not allowed.
TypeError: If `half_pixel_centers` type is not allowed.
TypeError: If `size` type is not int32.
TypeError: If `align_corners` type is not bool.
TypeError: If `half_pixel_centers` type is not bool.
ValueError: If `images` dim is not 4.
ValueError: If `size` dim is not 1.
ValueError: If `size` size is not 2.
ValueError: If `size` value is not positive.
ValueError: If `align_corners` and `half_pixel_centers` value are both true.
ValueError: If any `size` value is not positive.
ValueError: If `align_corners` and `half_pixel_centers` value are both True.
Supported Platforms:
@ -848,32 +856,32 @@ class ResizeArea(Primitive):
The resizing process only changes the two dimensions of images, which represent the width and height of images.
.. warning::
The values of "size" must be greater than zero.
The values of `size` must be greater than zero.
Args:
align_corners (bool, optional): If true, the centers of the 4 corner pixels of the input and output
tensors are aligned, preserving the values at the corner pixels. Defaults: False.
Inputs:
- **images** (Tensor) - Input images must be a 4-D tensor with shape which is [batch, height, width, channels].
The format must be NHWC.
- **images** (Tensor) - Input images must be a 4-D tensor with shape
which is :math:`(batch, channels, height, width)`. The format must be NHWC.
Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
- **size** (Tensor) - Input size must be a 1-D tensor of 2 elements: new_height, new_width.
The new size of output image.
Types allowed: int32.
Outputs:
A 4-D tensor of shape [batch, new_height, new_width, channels] with type: float32.
A 4-D tensor of shape :math:`(batch, new_height, new_width, channels)` with type float32.
Raises:
TypeError: If dtype of `images` is not supported.
TypeError: If dtype of `size` is not int32.
TypeError: If dtype of `align_corners` is not bool.
ValueError: If the num of inputs is not 2.
ValueError: If the dimension of `images` shape is not 4.
ValueError: If the dimension of `size` shape is not 1.
ValueError: If the element num of `size` is not [new_height, new_width].
ValueError: The size is not positive.
ValueError: If the dimension of `images` is not 4.
ValueError: If the dimension of `size` is not 1.
ValueError: If the element num of `size` is not 2.
ValueError: If any value of `size` is not positive.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``

View File

@ -6626,12 +6626,12 @@ class RaggedRange(Primitive):
- **rt_dense_values** (Tensor) - The dense values of the return `RaggedTensor`,
and type of the tensor should be same as input `starts`.
Let size of input `starts`, input `limits` and input `deltas` are i,
if type of the input `starts`, input `limits` and input `deltas`
are int32 or int64, shape of the output `rt_dense_values` is equal to
sum(abs(limits[i] - starts[i]) + abs(deltas[i]) - 1) / abs(deltas[i])),
if type of the input `starts`, input `limits` and input `deltas`
are float32 or float64, shape of the output `rt_dense_values` is equal to
sum(ceil(abs((limits[i] - starts[i]) / deltas[i]))).
- if type of the input `starts`, input `limits` and input `deltas`
are int32 or int64, shape of the output `rt_dense_values` is equal to
sum(abs(limits[i] - starts[i]) + abs(deltas[i]) - 1) / abs(deltas[i])),
- if type of the input `starts`, input `limits` and input `deltas`
are float32 or float64, shape of the output `rt_dense_values` is equal to
sum(ceil(abs((limits[i] - starts[i]) / deltas[i]))).
Raises:
TypeError: If any input is not Tensor.
TypeError: If the type of `starts` is not one of the following dtype: int32, int64, float32, float64.
@ -6921,7 +6921,7 @@ class Renorm(Primitive):
`maxnorm`. Otherwise the sub-tensor needs to be modified to the original value of the corresponding position
divided by the p-norm of the substensor and then multiplied by `maxnorm`.
Refer to :func::`mindspore.ops.renorm` for more detail.
Refer to :func::`mindspore.ops.renorm` for more details.
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
@ -7650,7 +7650,7 @@ class Qr(Primitive):
If False (the default), compute the P columns of q where P is minimum of the 2 innermost dimensions of x.
Args:
full_matrices (bool): The default value is Fasle.
- **full_matrices** (bool, optional) - Whether compute full-sized QR decomposition. Default: False.
Inputs:
- **x** (Tensor) - A matrix to be calculated. The matrix must be at least two dimensions.
@ -7659,10 +7659,10 @@ class Qr(Primitive):
Outputs:
- **q** (Tensor) - The orthonormal matrices of x.
If `full_matrices` is true, the shape is (m, m), else the shape is (m, p).
If `full_matrices` is true, the shape is :math:`(m, m)`, else the shape is :math:`(m, p)`.
The dtype of `q` is same as `x`.
- **r** (Tensor) - The upper triangular matrices of x.
If `full_matrices` is true, the shape is (m, n), else the shape is (p, n).
If `full_matrices` is true, the shape is :math:`(m, n)`, else the shape is :math:`(p, n)`.
The dtype of `r` is same as `x`.
Raises:

View File

@ -9379,7 +9379,7 @@ class NthElement(Primitive):
class PSROIPooling(Primitive):
r"""
Position Sensitive ROI-Pooling
Applies Position Sensitive ROI-Pooling on input Tensor.
Args:
spatial_scale (float): a scaling factor that maps the box coordinates to the input coordinates.
@ -9400,7 +9400,16 @@ class PSROIPooling(Primitive):
0 <= x1 < x2 and 0 <= y1 < y2.
Outputs:
- out (rois.shape[0] * rois.shape[2], output_dim, group_size, group_size), the result after pooling.
- **out** (Tensor) - The result after pooling. Its shape
is :math:`(rois.shape[0] * rois.shape[2], output\_dim, group\_size, group\_size)`.
Raises:
TypeError: If `spatial_scale` is not a float.
TypeError: If `group_size` or `output_dim` is not an int.
TypeError: If `features` or `rois` is not a Tensor.
TypeError: If dtype of `rois` is not float16 or float32.
ValueError: If shape of `features` does not satisfy :math:`(C == output_dim * group_size * group_size)`.
ValueError: If `spatial_scale` is negative.
Supported Platforms:
``Ascend``

View File

@ -225,9 +225,10 @@ class RandomGamma(Primitive):
operator-level random seed.
Args:
seed (int): The operator-level random seed, used to generate random numbers, must be non-negative. Default: 0.
seed2 (int): The global random seed and it will combile with the operator-level random seed to determine the
final generated random number, must be non-negative. Default: 0.
seed (int, optional): The operator-level random seed, used to generate random numbers,
must be non-negative. Default: 0.
seed2 (int, optional): The global random seed and it will combile with the operator-level
random seed to determine the final generated random number, must be non-negative. Default: 0.
Inputs:
- **shape** (Tensor) - The shape of random tensor to be generated.
@ -438,38 +439,38 @@ class Gamma(PrimitiveWithInfer):
class ParameterizedTruncatedNormal(Primitive):
"""
Returns a tensor of the specified shape filled with truncated normal values.
When 'shape' is (batch_size, *), the shape of 'mean', 'stdevs', 'min', 'max' should be () or (batch_size, ).
When `shape` is :math:`(batch_size, *)`, the shape of `mean`, `stdevs`,
`min` and `max` should be :math:`()` or :math:`(batch_size, )`.
Note:
The number in tensor minval must be strictly less than maxval at any position after broadcasting.
The value in tensor `min` must be strictly less than `max` at any position after broadcasting.
Args:
seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
the seed is set by the given seed. Otherwise, it is seeded by a random seed.
seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
seed (int, optional): Random number seed. If either `seed` or `seed2` are set to be non-zero,
the seed is set by the given seed. Otherwise, it is seeded by a random seed. Default: 0.
seed2 (int, optional): A second seed to avoid seed collision. Default: 0.
Inputs:
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
int32 and int64.
- **mean** (Tensor) - A Tensor. The parameter defines the mean of truncated normal distribution.
- **mean** (Tensor) - The parameter defines the mean of truncated normal distribution.
Its type must be one of the following types:float16, float32, float64.
- **stdevs** (Tensor) - A Tensor. The parameter defines the standard deviation for truncation of
- **stdevs** (Tensor) - The parameter defines the standard deviation for truncation of
the normal distribution. It must be greater than 0 and have the same type as means.
- **min** (Tensor) - The distribution parameter, a. The parameter defines the minimum of
- **min** (Tensor) - The parameter defines the minimum of
truncated normal distribution. It must have the same type as means.
- **max** (Tensor) - The distribution parameter, b. The parameter defines the maximum of
- **max** (Tensor) - The parameter defines the maximum of
truncated normal distribution. It must have the same type as means.
Outputs:
Tensor. Its shape is specified by the input `shape` and it must have the same type as means.
Raises:
TypeError: If `shape`, `mean`, `stdevs`, `min`, `max` and input tensor type are not allowed.
TypeError: If data type of `shape`, `mean`, `stdevs`, `min` and `max` are not allowed.
TypeError: If `mean`, `stdevs`, `min`, `max` don't have the same type.
TypeError: If `mean` or `stdevs` or `minval` or `maxval` is not a Tensor.
ValueError: When 'shape' is (batch_size, *), if the shape of 'mean', 'stdevs', 'min', 'max'
is not () or (batch_size, ).
TypeError: If any of `shape`, `mean`, `stdevs`, `min` and `max` is not Tensor.
ValueError: When `shape` is :math:`(batch_size, *)`, if the shape of `mean`, `stdevs`, `min` or `max`
is not :math:`()` or :math:`(batch_size, )`.
ValueError: If `shape` elements are not positive.
ValueError: If `stdevs` elements are not positive.
ValueError: If `shape` has less than 2 elements.
@ -569,25 +570,26 @@ class Poisson(PrimitiveWithInfer):
class RandomPoisson(Primitive):
r"""
Produces random non-negative values i, distributed according to discrete probability function:
Produces random non-negative values i, distributed according to discrete probability function:
.. math::
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}
Args:
seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
the seed is set by the given seed. Otherwise, it is seeded by a random seed.
seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
dtype (mindspore.dtype): The type of output. Default: mindspore.int64.
seed (int, optional): Random number seed. If either `seed` or `seed2` are set to be non-zero,
the seed is set by the given seed. Otherwise, it is seeded by a random seed. Default: 0.
seed2 (int, optional): A second seed to avoid seed collision. Default: 0.
dtype (mindspore.dtype, optional): The type of output. Default: mindspore.int64.
Inputs:
- **shape** (Tensor) - The shape of random tensor to be generated, 1-D Tensor, whose dtype must be in
[int32, int64]
[int32, int64].
- **rate** (Tensor) - μ parameter the distribution was constructed with. The parameter defines mean number
of occurrences of the event. Its type must be in [float16, float32, float64, int32, int64]
of occurrences of the event. Its type must be in [float16, float32, float64, int32, int64].
Outputs:
Tensor. Its shape is (*shape, *rate.shape). Its type is specified by `dtype`.
Tensor. Its shape is :math:`(*shape, *rate.shape)`. Its type is specified by `dtype`.
Raises:
TypeError: If `shape` is not a Tensor or its dtype is not int32 or int64.