fix amp docs
This commit is contained in:
parent
e9e62b92ec
commit
4e577ce104
|
@ -12,6 +12,7 @@ mindspore.amp.all_finite
|
|||
|
||||
参数:
|
||||
- **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。
|
||||
- **status** (Tensor) - 溢出检测时所需要的初始状态,仅在Ascend需要。默认值:None。
|
||||
|
||||
返回:
|
||||
Tensor,布尔类型的标量Tensor。
|
||||
|
|
|
@ -63,7 +63,7 @@ def init_status():
|
|||
Returns a Tensor indicating initialized status for overflow detection.
|
||||
|
||||
Note:
|
||||
Only Ascend need status to capture overflow status, you can alse call
|
||||
Only Ascend need status to capture overflow status, you can also call
|
||||
this function on GPU or CPU, but the return value is useless.
|
||||
|
||||
Returns:
|
||||
|
@ -73,7 +73,7 @@ def init_status():
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> status = init_status()
|
||||
>>> status = amp.init_status()
|
||||
"""
|
||||
if _ascend_target():
|
||||
status = ops.NPUAllocFloatStatus()()
|
||||
|
@ -98,6 +98,8 @@ def all_finite(inputs, status=None):
|
|||
|
||||
Args:
|
||||
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
|
||||
status (Tensor): the status Tensor for overflow detection, only required on
|
||||
Ascend. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, a scalar Tensor and the dtype is bool.
|
||||
|
@ -107,7 +109,7 @@ def all_finite(inputs, status=None):
|
|||
|
||||
Examples:
|
||||
>>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0]))
|
||||
>>> output = all_finite(x)
|
||||
>>> output = amp.all_finite(x)
|
||||
"""
|
||||
if _ascend_target():
|
||||
if status is None:
|
||||
|
|
|
@ -58,17 +58,17 @@ class FixedLossScaleManager(LossScaleManager):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore import amp, nn
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> #1) Drop the parameter update if there is an overflow
|
||||
>>> loss_scale_manager = ms.FixedLossScaleManager()
|
||||
>>> loss_scale_manager = amp.FixedLossScaleManager()
|
||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||
>>>
|
||||
>>> #2) Execute parameter update even if overflow occurs
|
||||
>>> loss_scale = 1024.0
|
||||
>>> loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False)
|
||||
>>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
|
||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
|
||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||
"""
|
||||
|
@ -133,10 +133,10 @@ class DynamicLossScaleManager(LossScaleManager):
|
|||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore import amp, nn
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> loss_scale_manager = ms.DynamicLossScaleManager()
|
||||
>>> loss_scale_manager = amp.DynamicLossScaleManager()
|
||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue