!46329 fix max_unpool3d test sample

Merge pull request !46329 from liuchao/code_docs_unpool3d_alpha
This commit is contained in:
i-robot 2022-12-08 06:33:24 +00:00 committed by Gitee
commit d80bcf746a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 9 additions and 6 deletions

View File

@ -914,7 +914,7 @@ def max_unpool3d(x, indices, kernel_size, stride=None, padding=0, output_size=No
:math:`[(N, C, D_{out} - stride[0], H_{out} - stride[1], W_{out} - stride[2]),
(N, C, D_{out} + stride[0], H_{out} + stride[1], W_{out} + stride[2])]`.
Outputs:
Returns:
Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
with the same data type with `x`.
@ -935,11 +935,14 @@ def max_unpool3d(x, indices, kernel_size, stride=None, padding=0, output_size=No
Examples:
>>> x = Tensor(np.array([[[[[0, 1], [8, 9]]]]]).astype(np.float32))
>>> indices= Tensor(np.array([[[[[0, 1], [2, 3]]]]]).astype(np.int64))
>>> maxunpool3d = nn.MaxUnpool3d(kernel_size=1, stride=1, padding=0)
>>> output = maxunpool3d(x, indices)
>>> print(output.asnumpy())
[[[[[0. 1.]
[8. 9.]]]]]
>>> output = ops.max_unpool3d(x, indices, kernel_size=2, stride=1, padding=0)
>>> print(output)
[[[[[0. 1. 8.]
[9. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]]]
"""
if stride is None:
stride = 0