fix amp docs

lvyufeng 2023-01-04 17:52:29 +08:00
parent e9e62b92ec
commit 4e577ce104
3 changed files with 11 additions and 8 deletions


@@ -12,6 +12,7 @@ mindspore.amp.all_finite
Parameters:
- **inputs** (Union(tuple(Tensor), list(Tensor))) - An iterable of Tensors.
+ - **status** (Tensor) - The initial status required for overflow detection; only needed on Ascend. Default: None.
Returns:
Tensor, a scalar Tensor of bool type.
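
For context, a minimal usage sketch of the signature documented above; the input values are illustrative, and `np.log(-1)` produces a NaN so the check reports non-finite:

    import numpy as np
    from mindspore import amp, Tensor

    # One NaN among the inputs should make all_finite return a scalar
    # bool Tensor holding False; all-finite inputs would yield True.
    x = (Tensor(np.array([np.log(-1), 1.0])), Tensor(np.array([1.0])))
    output = amp.all_finite(x)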


@@ -63,7 +63,7 @@ def init_status():
Returns a Tensor indicating initialized status for overflow detection.
Note:
- Only Ascend need status to capture overflow status, you can alse call
+ Only Ascend need status to capture overflow status, you can also call
this function on GPU or CPU, but the return value is useless.
Returns:
@@ -73,7 +73,7 @@ def init_status():
``Ascend`` ``GPU`` ``CPU``
Examples:
- >>> status = init_status()
+ >>> status = amp.init_status()
"""
if _ascend_target():
status = ops.NPUAllocFloatStatus()()
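
A short sketch of the corrected call outside the docstring; per the Note above, the returned status only carries meaning on Ascend:

    from mindspore import amp

    # Allocate the initial overflow-status Tensor. The call also works
    # on GPU and CPU, but the return value is unused there.
    status = amp.init_status()
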
@@ -98,6 +98,8 @@ def all_finite(inputs, status=None):
Args:
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
+ status (Tensor): the status Tensor for overflow detection, only required on
+     Ascend. Default: None.
Returns:
Tensor, a scalar Tensor and the dtype is bool.
@@ -107,7 +109,7 @@ def all_finite(inputs, status=None):
Examples:
>>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])))
- >>> output = all_finite(x)
+ >>> output = amp.all_finite(x)
"""
if _ascend_target():
if status is None:
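
Putting the two corrected examples together, a sketch of the intended overflow-check pattern; the gradient values are made up, and the bare truthiness test assumes a scalar bool Tensor:

    import numpy as np
    from mindspore import amp, Tensor

    status = amp.init_status()                 # Ascend-only state, harmless elsewhere
    grads = (Tensor(np.array([1.0, 2.0])), Tensor(np.array([np.inf])))
    is_finite = amp.all_finite(grads, status)  # False here: one gradient is inf
    if not is_finite:
        print("overflow detected, skip this parameter update")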


@@ -58,17 +58,17 @@ class FixedLossScaleManager(LossScaleManager):
Examples:
>>> import mindspore as ms
- >>> from mindspore import nn
+ >>> from mindspore import amp, nn
>>>
>>> net = Net()
>>> #1) Drop the parameter update if there is an overflow
- >>> loss_scale_manager = ms.FixedLossScaleManager()
+ >>> loss_scale_manager = amp.FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
>>>
>>> #2) Execute parameter update even if overflow occurs
>>> loss_scale = 1024.0
- >>> loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False)
+ >>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
"""
@@ -133,10 +133,10 @@ class DynamicLossScaleManager(LossScaleManager):
Examples:
>>> import mindspore as ms
- >>> from mindspore import nn
+ >>> from mindspore import amp, nn
>>>
>>> net = Net()
- >>> loss_scale_manager = ms.DynamicLossScaleManager()
+ >>> loss_scale_manager = amp.DynamicLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
"""