fix amp docs
This commit is contained in:
parent
e9e62b92ec
commit
4e577ce104
|
@ -12,6 +12,7 @@ mindspore.amp.all_finite
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。
|
- **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。
|
||||||
|
- **status** (Tensor) - 溢出检测时所需要的初始状态,仅在Ascend需要。默认值:None。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Tensor,布尔类型的标量Tensor。
|
Tensor,布尔类型的标量Tensor。
|
||||||
|
|
|
@ -63,7 +63,7 @@ def init_status():
|
||||||
Returns a Tensor indicating initialized status for overflow detection.
|
Returns a Tensor indicating initialized status for overflow detection.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Only Ascend need status to capture overflow status, you can alse call
|
Only Ascend need status to capture overflow status, you can also call
|
||||||
this function on GPU or CPU, but the return value is useless.
|
this function on GPU or CPU, but the return value is useless.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -73,7 +73,7 @@ def init_status():
|
||||||
``Ascend`` ``GPU`` ``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> status = init_status()
|
>>> status = amp.init_status()
|
||||||
"""
|
"""
|
||||||
if _ascend_target():
|
if _ascend_target():
|
||||||
status = ops.NPUAllocFloatStatus()()
|
status = ops.NPUAllocFloatStatus()()
|
||||||
|
@ -98,6 +98,8 @@ def all_finite(inputs, status=None):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
|
inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
|
||||||
|
status (Tensor): the status Tensor for overflow detection, only required on
|
||||||
|
Ascend. Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, a scalar Tensor and the dtype is bool.
|
Tensor, a scalar Tensor and the dtype is bool.
|
||||||
|
@ -107,7 +109,7 @@ def all_finite(inputs, status=None):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0]))
|
>>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0]))
|
||||||
>>> output = all_finite(x)
|
>>> output = amp.all_finite(x)
|
||||||
"""
|
"""
|
||||||
if _ascend_target():
|
if _ascend_target():
|
||||||
if status is None:
|
if status is None:
|
||||||
|
|
|
@ -58,17 +58,17 @@ class FixedLossScaleManager(LossScaleManager):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore as ms
|
>>> import mindspore as ms
|
||||||
>>> from mindspore import nn
|
>>> from mindspore import amp, nn
|
||||||
>>>
|
>>>
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> #1) Drop the parameter update if there is an overflow
|
>>> #1) Drop the parameter update if there is an overflow
|
||||||
>>> loss_scale_manager = ms.FixedLossScaleManager()
|
>>> loss_scale_manager = amp.FixedLossScaleManager()
|
||||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||||
>>>
|
>>>
|
||||||
>>> #2) Execute parameter update even if overflow occurs
|
>>> #2) Execute parameter update even if overflow occurs
|
||||||
>>> loss_scale = 1024.0
|
>>> loss_scale = 1024.0
|
||||||
>>> loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False)
|
>>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
|
||||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
|
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
|
||||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||||
"""
|
"""
|
||||||
|
@ -133,10 +133,10 @@ class DynamicLossScaleManager(LossScaleManager):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore as ms
|
>>> import mindspore as ms
|
||||||
>>> from mindspore import nn
|
>>> from mindspore import amp, nn
|
||||||
>>>
|
>>>
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> loss_scale_manager = ms.DynamicLossScaleManager()
|
>>> loss_scale_manager = amp.DynamicLossScaleManager()
|
||||||
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue