fix amp docs

This commit is contained in:
lvyufeng 2023-01-04 17:52:29 +08:00
parent e9e62b92ec
commit 4e577ce104
3 changed files with 11 additions and 8 deletions

View File

@ -12,6 +12,7 @@ mindspore.amp.all_finite
参数: 参数:
- **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。 - **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。
- **status** (Tensor) - 溢出检测时所需要的初始状态仅在Ascend需要。默认值None。
返回: 返回:
Tensor布尔类型的标量Tensor。 Tensor布尔类型的标量Tensor。

View File

@ -63,7 +63,7 @@ def init_status():
Returns a Tensor indicating initialized status for overflow detection. Returns a Tensor indicating initialized status for overflow detection.
Note: Note:
Only Ascend need status to capture overflow status, you can alse call Only Ascend need status to capture overflow status, you can also call
this function on GPU or CPU, but the return value is useless. this function on GPU or CPU, but the return value is useless.
Returns: Returns:
@ -73,7 +73,7 @@ def init_status():
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> status = init_status() >>> status = amp.init_status()
""" """
if _ascend_target(): if _ascend_target():
status = ops.NPUAllocFloatStatus()() status = ops.NPUAllocFloatStatus()()
@ -98,6 +98,8 @@ def all_finite(inputs, status=None):
Args: Args:
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor. inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
status (Tensor): the status Tensor for overflow detection, only required on
Ascend. Default: None.
Returns: Returns:
Tensor, a scalar Tensor and the dtype is bool. Tensor, a scalar Tensor and the dtype is bool.
@ -107,7 +109,7 @@ def all_finite(inputs, status=None):
Examples: Examples:
>>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])) >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0]))
>>> output = all_finite(x) >>> output = amp.all_finite(x)
""" """
if _ascend_target(): if _ascend_target():
if status is None: if status is None:

View File

@ -58,17 +58,17 @@ class FixedLossScaleManager(LossScaleManager):
Examples: Examples:
>>> import mindspore as ms >>> import mindspore as ms
>>> from mindspore import nn >>> from mindspore import amp, nn
>>> >>>
>>> net = Net() >>> net = Net()
>>> #1) Drop the parameter update if there is an overflow >>> #1) Drop the parameter update if there is an overflow
>>> loss_scale_manager = ms.FixedLossScaleManager() >>> loss_scale_manager = amp.FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
>>> >>>
>>> #2) Execute parameter update even if overflow occurs >>> #2) Execute parameter update even if overflow occurs
>>> loss_scale = 1024.0 >>> loss_scale = 1024.0
>>> loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False) >>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale) >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
""" """
@ -133,10 +133,10 @@ class DynamicLossScaleManager(LossScaleManager):
Examples: Examples:
>>> import mindspore as ms >>> import mindspore as ms
>>> from mindspore import nn >>> from mindspore import amp, nn
>>> >>>
>>> net = Net() >>> net = Net()
>>> loss_scale_manager = ms.DynamicLossScaleManager() >>> loss_scale_manager = amp.DynamicLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
""" """