Merge pull request !47025 from 冯一航/fix_pad
This commit is contained in:
i-robot 2022-12-21 07:53:48 +00:00 committed by Gitee
commit 5865826e1c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 15 additions and 4 deletions

View File

@ -10,3 +10,6 @@ mindspore.ops.numel
返回:
int。Tensor的元素的总数量。
异常:
- **TypeError** - `x` 不是Tensor。

View File

@ -12,4 +12,5 @@ mindspore.ops.positive
输入Tensor。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的dtype是bool。

View File

@ -114,7 +114,6 @@ tensor_le = P.LessEqual()
tensor_gt = P.Greater()
tensor_ge = P.GreaterEqual()
not_equal_ = P.NotEqual()
size_ = P.Size()
transpose_ = P.Transpose()
cast_ = P.Cast()
@ -601,6 +600,7 @@ def positive(x):
Tensor, self input.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If the dtype of self Tensor is bool type.
Supported Platforms:
@ -613,6 +613,8 @@ def positive(x):
>>> print(ops.positive(x))
[-5.0, 1.5, 3.0, 100.0]
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError(f"For positive, the input must be a Tensor, but got {type(x)}")
if x.dtype == mstype.bool_:
raise TypeError("For positive, the type of tensor can not be bool.")
return x
@ -628,6 +630,9 @@ def numel(x):
Returns:
int. A scalar representing the total of elements in the Tensor.
Raises:
TypeError: If `x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -636,7 +641,9 @@ def numel(x):
>>> print(ops.numel(input_x))
4
"""
return size_(x)
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError(f"For numel, the input must be a Tensor, but got {type(x)}")
return x.size
def permute(x, dims):

View File

@ -2464,7 +2464,7 @@ def pad(input_x, padding, mode='constant', value=None):
>>> import numpy as np
>>> x = ms.Tensor(np.arange(1 * 2 * 2 * 2).reshape((1, 2, 2, 2)), dtype=ms.float64)
>>> output = ops.pad(x, [1, 0, 0, 1], mode='constant', value=6.0)
>>> print(x)
>>> print(output)
[[[[6. 0. 1.]
[6. 2. 3.]
[6. 6. 6.]]
@ -2508,7 +2508,7 @@ def pad(input_x, padding, mode='constant', value=None):
raise ValueError(f"For 'pad', the padding mode '{mode}' can not set value, but got value {value}.")
if mode == "replicate":
mode = "edge"
if padding.shape[0] == input_x.ndim + 1:
if padding.shape[0] // 2 + 1 == input_x.ndim:
input_x = input_x.expand_dims(0)
is_expand = True
out = PadV3(mode=mode, paddings_contiguous=True)(input_x, padding, value)