!45653 add note for MaskedFill

Merge pull request !45653 from 范吉斌/code_docs_maskefill
This commit is contained in:
i-robot 2022-11-22 08:51:15 +00:00 committed by Gitee
commit 4bd89d1e5f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 13 additions and 2 deletions

View File

@ -5,4 +5,8 @@ mindspore.ops.MaskedFill
将掩码位置为True的位置填充指定的值。
.. note::
如果 `value` 是Python类型的浮点数则默认会转为float32类型。这种情况下`input_x` 为float16类型时在CPU和Ascend平台上`input_x` 会转为float32类型参与计算
并将结果类型转换到float16类型可能会造成一定性能损耗而在GPU平台上则会引起TypeError。因此建议 `value` 采用与 `input_x` 具有相同数据类型的Tensor。
更多参考详见 :func:`mindspore.ops.masked_fill`

View File

@ -1226,7 +1226,7 @@ class AdaptiveMaxPool3d(Cell):
ValueError: If `output_size` is neither an int nor a tuple with shape (3,).
Supported Platforms:
``GPU``
``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.arange(0,36).reshape((1, 3, 3, 4)).astype(np.float32))

View File

@ -655,7 +655,7 @@ def adaptive_max_pool3d(x, output_size, return_indices=False):
ValueError: If `output_size` is neither an int nor a tuple with shape (3,).
Supported Platforms:
``GPU``
``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.arange(0,36).reshape((1, 3, 3, 4)).astype(np.float32))

View File

@ -5987,6 +5987,13 @@ class MaskedFill(Primitive):
"""
Fills elements with value where mask is True.
Note:
If `value` is a floating-point number of Python, it will be converted to float32 later by default.
In this case, if `input_x` is a float16 Tensor, it will be converted to float32 for calculation,
and the result type will be converted back to float16 on the CPU and Ascend platforms, which may
cause the performance penalty. A TypeError may be raised on the GPU platform. Therefore,
it is recommended that 'value' should use a Tensor with the same dtype as `input_x`.
Refer to :func:`mindspore.ops.masked_fill` for more details.
Supported Platforms: