!45852 add CN docs for NonMaxSuppressionV3 and PadV3 and move Roll from inner…

Merge pull request !45852 from 李林杰/1122_add_a_few_CN_doc_master
This commit is contained in:
i-robot 2022-11-23 07:12:02 +00:00 committed by Gitee
commit c0445d8a4e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 208 additions and 123 deletions

View File

@ -80,6 +80,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
mindspore.ops.NuclearNorm
mindspore.ops.Pad
mindspore.ops.Padding
mindspore.ops.PadV3
mindspore.ops.ResizeNearestNeighbor
mindspore.ops.ResizeBilinear
mindspore.ops.UpsampleNearest3D
@ -220,6 +221,7 @@ MindSpore中 `mindspore.ops` 接口与上一版本相比,新增、删除和支
mindspore.ops.IOU
mindspore.ops.L2Normalize
mindspore.ops.NMSWithMask
mindspore.ops.NonMaxSuppressionV3
mindspore.ops.NonMaxSuppressionWithOverlaps
mindspore.ops.PSROIPooling
mindspore.ops.RGBToHSV
@ -543,6 +545,7 @@ Array操作
mindspore.ops.ReverseSequence
mindspore.ops.ReverseV2
mindspore.ops.RightShift
mindspore.ops.Roll
mindspore.ops.ScatterAddWithAxis
mindspore.ops.ScatterNd
mindspore.ops.ScatterNdDiv

View File

@ -0,0 +1,35 @@
mindspore.ops.NonMaxSuppressionV3
==================================
.. py:class:: mindspore.ops.NonMaxSuppressionV3
贪婪选取一组按score降序排列后的边界框。
.. warning::
如果 `max_output_size` 小于0将使用0代替。
.. note::
- 此算法与原点在坐标系中的位置无关。
- 对于坐标系的正交变换和平移,该算法不受影响,因此坐标系的平移变换后算法会选择相同的框。
输入:
- **boxes** (Tensor) - 二维Tensorshape为 :math:`(num_boxes, 4)`
- **scores** (Tensor) - 一维Tensor其shape为 :math:`(num_boxes)` 。表示对应每一行每个方框的score值 `scores``boxes` 的num_boxes必须相等。支持的数据类型为float32。
- **max_output_size** (Union[Tensor, Number.Int]) - 选取最大的边框数必须大于等于0数据类型为int32。
- **iou_threshold** (Union[Tensor, Number.Float]) - 边框重叠值阈值重叠值大于此值说明重叠过大其值必须大于等于0小于等于1。支持的数据类型为float32。
- **score_threshold** (Union[Tensor, Number.Float]) - 移除边框阈值边框score值大于此值则移除相应边框。支持的数据类型为float32。
输出:
一维Tensor表示被选中边框的index其shape为 :math:`(M)` 其中M <= `max_output_size`
异常:
- **TypeError** - `boxes``scores` 的数据类型不一致。
- **TypeError** - `iou_threshold``score_threshold` 的数据类型不一致。
- **TypeError** - `boxes` 的数据类型不是float16或者float32。
- **TypeError** - `scores` 的数据类型不是float16或者float32。
- **TypeError** - `max_output_size` 不是Tensor或者Scalar或者其数据类型不是int32或int64。
- **TypeError** - `iou_threshold` 不是Tesnor或者Scalar或者其数据类型不是float16或float32。
- **TypeError** - `score_threshold` 不是Tesnor或者Scalar或者其数据类型不是float16或float32。
- **ValueError** - `boxes` 的shape长度不是2或者第二维度的值不是4。
- **ValueError** - `scores` shape长度不是1。
- **ValueError** - `max_output_size``iou_threshold``score_threshold` 的shape长度不是0。

View File

@ -0,0 +1,32 @@
mindspore.ops.PadV3
====================
.. py:class:: mindspore.ops.PadV3(mode="constant", paddings_contiguous=True)
根据参数 `mode``paddings_contiguous` 对输入进行填充。
参数:
- **mode** (str可选) - 填充模式,支持"constant" 、"reflect" 和 "edge"。默认值:"constant"。
- **paddings_contiguous** (bool可选) - 是否连续填充。如果为True `paddings` 格式为[begin0, end0, begin1, end1, ...]如果为False`paddings` 格式为[begin0, begin1, ..., end1, end2, ...]。默认值True。
输入:
- **x** (Tensor) - Pad的输入任意维度的Tensor。
- **paddings** (Tensor) - Pad的输入任意维度的Tensor。
- **constant_value** (Tensor) - Pad的输入任意维度的Tensor。
输出:
填充后的Tensor。
异常:
- **TypeError** - `x``paddings` 不是Tensor。
- **TypeError** - `padding_contiguous` bool。
- **ValueError** - `mode` 不是string类型或者不在支持的列表里。
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数不是偶数。
- **ValueError** - `mode` 是"constant"的同时 `paddings` 元素个数大于输入维度乘以2。
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `paddings` 元素个数不是2、4或6。
- **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是3 `paddings` 元素个数是2。
- **ValueError** - `mode` 是"edge"或"reflect" `x` 的维度是4 `paddings` 元素个数是4。
- **ValueError** - `mode` 是"edge"或"reflect"的同时 `x` 的维度小于3。
- **ValueError** - `mode` 是"edge"的同时 `x` 的维度大于5。
- **ValueError** - `mode` 是"reflect"的同时填充值大于对应 `x` 的维度。
- **ValueError** - 填充之后输出shape数不大于零。

View File

@ -0,0 +1,8 @@
mindspore.ops.Roll
===================
.. py:class:: mindspore.ops.Roll(shift, axis)
沿轴移动Tensor的元素。
更多参考详见 :func:`mindspore.ops.roll`

View File

@ -80,6 +80,7 @@ Neural Network
mindspore.ops.Pad
mindspore.ops.EmbeddingLookup
mindspore.ops.Padding
mindspore.ops.PadV3
mindspore.ops.ResizeNearestNeighbor
mindspore.ops.ResizeBilinear
mindspore.ops.UpsampleNearest3D
@ -219,6 +220,7 @@ Image Processing
mindspore.ops.IOU
mindspore.ops.L2Normalize
mindspore.ops.NMSWithMask
mindspore.ops.NonMaxSuppressionV3
mindspore.ops.NonMaxSuppressionWithOverlaps
mindspore.ops.PSROIPooling
mindspore.ops.RGBToHSV
@ -542,6 +544,7 @@ Array Operation
mindspore.ops.ReverseSequence
mindspore.ops.ReverseV2
mindspore.ops.RightShift
mindspore.ops.Roll
mindspore.ops.ScatterAddWithAxis
mindspore.ops.ScatterNd
mindspore.ops.ScatterNdDiv

View File

@ -1656,7 +1656,7 @@ class Roll(Cell):
Validator.check_is_int(s_axis, "axis", "Roll")
for s_shift in self.shift:
Validator.check_is_int(s_shift, "shift", "Roll")
self.roll = inner.Roll(self.shift, self.axis)
self.roll = P.Roll(self.shift, self.axis)
self.gpu = True
if len(self.shift) != len(self.axis):
raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
@ -1664,14 +1664,14 @@ class Roll(Cell):
f"and the length of 'axis' {len(self.axis)}.")
else:
if not isinstance(self.axis, (list, tuple)):
self.op_list.append((inner.Roll(shift=self.shift, axis=0), self.axis))
self.op_list.append((P.Roll(shift=self.shift, axis=0), self.axis))
else:
if len(self.shift) != len(self.axis):
raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
f"the same, but got the length of 'shift' {len(self.shift)} "
f"and the length of 'axis' {len(self.axis)}.")
for idx, _ in enumerate(self.axis):
self.op_list.append((inner.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))
self.op_list.append((P.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))
def construct(self, input_x):
dim = len(self.shape_op(input_x))

View File

@ -16,7 +16,6 @@
"""inner_ops"""
from __future__ import absolute_import
from mindspore import context
from mindspore.ops.operations.comm_ops import _VirtualPipelineEnd
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.operations import _inner_ops as inner
@ -40,27 +39,6 @@ def get_bprop_tensor_copy_slices(self):
return bprop
@bprop_getters.register(inner.Roll)
def get_bprop_roll(self):
"""Generate bprop for Roll"""
if context.get_context("device_target") == "GPU":
shift = []
axis = self.axis
for tmp in enumerate(self.shift):
shift.append(-tmp[1])
roll_grad = inner.Roll(shift, axis)
else:
shift = self.shift
axis = self.axis
roll_grad = inner.Roll(-shift, axis)
def bprop(x_input, out, dout):
dx = roll_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(_VirtualPipelineEnd)
def get_bprop_virtual_pipeline_end(self):
"""Backpropagator for _VirtualPipelineEnd."""

View File

@ -17,6 +17,7 @@
import numpy as np
import mindspore.numpy as mnp
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.nn import LGamma
from mindspore.ops import functional as F
@ -129,6 +130,27 @@ def get_bprop_logit(self):
return bprop
@bprop_getters.register(P.Roll)
def get_bprop_roll(self):
"""Generate bprop for Roll"""
if context.get_context("device_target") == "GPU":
shift = []
axis = self.axis
for tmp in enumerate(self.shift):
shift.append(-tmp[1])
roll_grad = P.Roll(shift, axis)
else:
shift = self.shift
axis = self.axis
roll_grad = P.Roll(-shift, axis)
def bprop(x_input, out, dout):
dx = roll_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(P.Cdist)
def get_bprop_cdist(self):
"""Generate bprop for Cdist"""

View File

@ -29,7 +29,7 @@ from mindspore.ops.operations.math_ops import STFT
from mindspore.ops.operations.math_ops import ReduceStd
from mindspore.ops.operations.math_ops import Logit
from mindspore.ops.operations.math_ops import LuUnpack
from mindspore.ops.operations._inner_ops import Roll
from mindspore.ops.operations.math_ops import Roll
from mindspore.nn import layer
from mindspore._checkparam import check_is_number
from mindspore._checkparam import Rel

View File

@ -28,7 +28,7 @@ from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive, Triu
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations._inner_ops import Roll
from mindspore.ops.operations.math_ops import Roll
from mindspore.ops.composite.array_ops import repeat_interleave
from mindspore.ops.composite.math_ops import mm

View File

@ -85,7 +85,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
FFTWithSize, Heaviside, Histogram, Hypot, Lcm, LuUnpack, MatrixExp,
MatrixLogarithm, MatrixPower, MatrixSolve, MatrixTriangularSolve, ReduceStd, STFT,
NextAfter, Orgqr, Qr, RaggedRange, Digamma, Eig, EuclideanNorm, CompareAndBitpack, ComplexAbs,
CumulativeLogsumexp, Gcd, Trace, TridiagonalMatMul, TrilIndices, TriuIndices, Zeta)
CumulativeLogsumexp, Gcd, Trace, TridiagonalMatMul, TrilIndices, TriuIndices, Zeta,
Roll)
from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
DepthwiseConv2dNative,
@ -112,7 +113,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
GridSampler3D, MaxPool3DWithArgmax, MaxUnpool2D, NuclearNorm, NthElement, MultilabelMarginLoss,
PSROIPooling, Dilation2D, DataFormatVecPermute, DeformableOffsets, FractionalAvgPool,
FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize,
GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D)
GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, UpdateState, Load, StopGradient,
CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale,
@ -640,6 +641,8 @@ __all__ = [
"UniqueConsecutive",
"UnravelIndex",
"Zeta",
"PadV3",
"Roll",
]
__custom__ = [

View File

@ -1281,69 +1281,6 @@ class TensorCopySlices(Primitive):
self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y'])
class Roll(Primitive):
"""
Rolls the elements of a tensor along an axis.
The elements are shifted positively (towards larger indices) by the offset of `shift` along the dimension of `axis`.
Negative `shift` values will shift elements in the opposite direction. Elements that roll passed the last position
will wrap around to the first and vice versa. Multiple shifts along multiple axes may be specified.
Note:
This inner operation is valid only if the axis is equal to 0. If the shift and the axis are tuples or lists,
this inner operation is valid only for the first pair of elements.
Args:
shift (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
in the opposite direction.
axis (Union[list(int), tuple(int), int]): Specifies the dimension indexes of shape to be rolled. The value is
forced to be zero in this operation.
Inputs:
- **input_x** (Tensor) - Input tensor.
Outputs:
Tensor, has the same shape and type as `input_x`.
Raises:
TypeError: If `shift` is not an int, a tuple or a list.
TypeError: If `axis` is not an int, a tuple or a list.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> from mindspore.ops.operations import _inner_ops as inner
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
>>> op = inner.Roll(shift=2, axis=0)
>>> output = op(input_x)
>>> print(output)
[3. 4. 0. 1. 2.]
>>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
>>> op = inner.Roll(shift=-1, axis=0)
>>> output = op(input_x)
>>> print(output)
[[5. 6. 7. 8. 9.]
[0. 1. 2. 3. 4.]]
"""
@prim_attr_register
def __init__(self, shift, axis):
"""Initialize Roll"""
if context.get_context("device_target") == "GPU":
validator.check_value_type("shift", shift, [int, tuple, list], self.name)
validator.check_value_type("axis", axis, [int, tuple, list], self.name)
else:
if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
validator.check_equal_int(len(shift), 1, "shift size", self.name)
validator.check_equal_int(len(axis), 1, "shift size", self.name)
validator.check_equal_int(axis[0], 0, "axis", self.name)
elif isinstance(shift, int) and isinstance(axis, int):
validator.check_equal_int(axis, 0, "axis", self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class DSDMatmul(PrimitiveWithInfer):
"""
The definition of the CusSquare primitive.

View File

@ -355,23 +355,23 @@ class NonMaxSuppressionV3(Primitive):
Greedily selects a subset of bounding boxes in descending order of score.
.. warning::
When input "max_output_size" is negative, it will be treated as 0.
When input `max_output_size` is negative, it will be treated as 0.
Note:
This algorithm is agnostic to where the origin is in the coordinate system.
This algorithm is invariant to orthogonal transformations and translations of the coordinate system;
thus translating or reflections of the coordinate system result in the same boxes being
selected by the algorithm.
- This algorithm is agnostic to where the origin is in the coordinate system.
- This algorithm is invariant to orthogonal transformations and translations of the coordinate system,
thus translating or reflections of the coordinate system result in the same boxes being
selected by the algorithm.
Inputs:
- **boxes** (Tensor) - A 2-D Tensor of shape [num_boxes, 4].
- **scores** (Tensor) - A 1-D Tensor of shape [num_boxes] representing a single score
corresponding to each box (each row of boxes), the num_boxes of "scores" must be equal to
the num_boxes of "boxes".
- **boxes** (Tensor) - A 2-D Tensor of shape :math:`(num_boxes, 4)`.
- **scores** (Tensor) - A 1-D Tensor of shape :math:`(num_boxes)` representing a single score
corresponding to each box (each row of boxes), the num_boxes of `scores` must be equal to
the num_boxes of `boxes`.
- **max_output_size** (Union[Tensor, Number.Int]) - A scalar integer Tensor representing the maximum
number of boxes to be selected by non max suppression.
- **iou_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
deciding whether boxes overlap too much with respect to IOU, and iou_threshold must be equal or greater
deciding whether boxes overlap too much with respect to IOU, and `iou_threshold` must be equal or greater
than 0 and be equal or smaller than 1.
- **score_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
deciding when to remove boxes based on score.
@ -381,16 +381,17 @@ class NonMaxSuppressionV3(Primitive):
where M <= max_output_size.
Raises:
TypeError: If the dtype of `boxes` and `scores` is different.
TypeError: If the dtype of `iou_threshold` and `score_threshold` is different.
TypeError: If the dtype of `boxes` and `scores` are different.
TypeError: If the dtype of `iou_threshold` and `score_threshold` are different.
TypeError: If `boxes` is not tensor or its dtype is not float16 or float32.
TypeEroor: If `scores` is not tensor or its dtype is not float16 or float32.
TypeError: If `max_output_size` is not tensor or scalar.If `max_output_size` is not int32 or int64.
TypeError: If `iou_threshold` is not tensor or scalar. If its type is not float16 or float32.
TypeError: If `score_threshold` is not tensor or scalar. If its type is not float16 or float32.
TypeError: If `max_output_size` is not tensor or scalar or its date type is not int32 or int64.
TypeError: If `iou_threshold` is not tensor or scalar or its type is neither float16 or float32.
TypeError: If `score_threshold` is not tensor or scalar or its type is neither float16 or float32.
ValueError: If the size of shape of `boxes` is not 2 or the second value of its shape is not 4.
ValueError: If the size of shape of `scores` is not 1.
ValueError: If each of the size of shape of `max_output_size`, `iou_threshold`, `score_threshold` is not 0.
ValueError: If any of the size of shape of `max_output_size`,
`iou_threshold`, `score_threshold` is not 0.
Supported Platforms:
``Ascend``

View File

@ -7505,7 +7505,7 @@ class TriuIndices(Primitive):
row (int): number of rows in the 2-D matrix.
col (int): number of columns in the 2-D matrix.
offset (int, optional): diagonal offset from the main diagonal. Default: 0.
dtype (:class:`mindspore.dtype`): The specified type of output tensor.
dtype (:class:`mindspore.dtype`, optional): The specified type of output tensor.
An optional data type of `mindspore.int32` and `mindspore.int64`. Default: `mindspore.int32`.
Outputs:
@ -7782,3 +7782,63 @@ class Ormqr(Primitive):
self.transpose = validator.check_value_type('transpose', transpose, [bool], self.name)
self.add_prim_attr('left', self.left)
self.add_prim_attr('transpose', self.transpose)
class Roll(Primitive):
"""
Rolls the elements of a tensor along an axis.
The elements are shifted positively (towards larger indices) by the offset of `shift` along the dimension of `axis`.
Negative `shift` values will shift elements in the opposite direction. Elements that roll passed the last position
will wrap around to the first and vice versa. Multiple shifts along multiple axes may be specified.
Note:
This inner operation is valid only if the axis is equal to 0. If the shift and the axis are tuples or lists,
this inner operation is valid only for the first pair of elements.
Args:
shift (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted
positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
in the opposite direction.
axis (Union[list(int), tuple(int), int]): Specifies the dimension indexes of shape to be rolled. The value is
forced to be zero in this operation.
Inputs:
- **input_x** (Tensor) - Input tensor.
Outputs:
Tensor, has the same shape and type as `input_x`.
Raises:
TypeError: If `shift` is not an int, a tuple or a list.
TypeError: If `axis` is not an int, a tuple or a list.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
>>> op = ops.Roll(shift=2, axis=0)
>>> output = op(input_x)
>>> print(output)
[3. 4. 0. 1. 2.]
>>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
>>> op = ops.Roll(shift=-1, axis=0)
>>> output = op(input_x)
>>> print(output)
"""
@prim_attr_register
def __init__(self, shift, axis):
"""Initialize Roll"""
if context.get_context("device_target") == "GPU":
validator.check_value_type("shift", shift, [int, tuple, list], self.name)
validator.check_value_type("axis", axis, [int, tuple, list], self.name)
else:
if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
validator.check_equal_int(len(shift), 1, "shift size", self.name)
validator.check_equal_int(len(axis), 1, "shift size", self.name)
validator.check_equal_int(axis[0], 0, "axis", self.name)
elif isinstance(shift, int) and isinstance(axis, int):
validator.check_equal_int(axis, 0, "axis", self.name)
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])

View File

@ -4359,14 +4359,15 @@ class Pad(Primitive):
class PadV3(Primitive):
"""
Pads the input tensor according to the paddings, mode and paddings_contiguous.
Pads the input tensor according to the paddings, `mode` and `paddings_contiguous`.
Args:
mode (str): An optional string, Defaults to "constant", indicates padding mode,
support "constant", "reflect", "edge", Defaults to "constant".
paddings_contiguous (bool): An optional bool value, Defaults to True.
mode (str, optional): An optional string indicates padding mode,
support "constant", "reflect", "edge". Default: "constant".
paddings_contiguous (bool, optional): An optional bool value indicates if the padding is paddings_contiguous.
If true, paddings is arranged as [begin0, end0, begin1, end1, ...]
If false, paddings is arranged as [begin0, begin1, ..., end1, end2, ...]
Default:True.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
@ -4381,16 +4382,18 @@ class PadV3(Primitive):
TypeError: If `x` or `paddings` is not a Tensor.
TypeError: If `padding_contiguous` is not a bool.
ValueError: If `mode` is not a str or not in support modes.
ValueError: If `mode` is constant, the element's number of paddings not be even.
ValueError: If `mode` is constant, the element's number of paddings large than input dim * 2.
ValueError: If `mode` is edge or reflect, the element's number of paddings is not 2, 4 or 6.
ValueError: If `mode` is edge or reflect, x dims equal 3, the element's number of paddings is 2.
ValueError: If `mode` is edge or reflect, x dims equal 4, the element's number of paddings is 4.
ValueError: If `mode` is edge or reflect, x dims smaller than 3.
ValueError: If `mode` is edge, x dims bigger than 5.
ValueError: If `mode` is reflect, x dims bigger than 4.
ValueError: If `mode` is reflect, padding size bigger than the corresponding x dimension.
ValueError: After padding, output's shape number must be greater than 0.
ValueError: If `mode` is "constant", the element's number of `paddings` not be even.
ValueError: If `mode` is "constant", the element's number of `paddings` large than input dim * 2.
ValueError: If `mode` is "edge" or "reflect", the element's number of `paddings` is not 2, 4 or 6.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
the element's number of `paddings` is 2.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
the element's number of `paddings` is 4.
ValueError: If `mode` is "edge" or "reflect", `x` dims smaller than 3.
ValueError: If `mode` is "edge", x dims bigger than 5.
ValueError: If `mode` is "reflect", x dims bigger than 4.
ValueError: If `mode` is "reflect", padding size bigger than the corresponding `x` dimension.
ValueError: After padding, output's shape number is not greater than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``