From 4e577ce104b4d427244e811fed7eabe5729d3b20 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Wed, 4 Jan 2023 17:52:29 +0800
Subject: [PATCH] fix amp docs

---
 docs/api/api_python/amp/mindspore.amp.all_finite.rst   |  1 +
 mindspore/python/mindspore/amp.py                       |  8 +++++---
 mindspore/python/mindspore/train/loss_scale_manager.py | 10 +++++-----
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/docs/api/api_python/amp/mindspore.amp.all_finite.rst b/docs/api/api_python/amp/mindspore.amp.all_finite.rst
index ce7e25820bc..38c000b9af9 100644
--- a/docs/api/api_python/amp/mindspore.amp.all_finite.rst
+++ b/docs/api/api_python/amp/mindspore.amp.all_finite.rst
@@ -12,6 +12,7 @@ mindspore.amp.all_finite

     参数：
         - **inputs** (Union(tuple(Tensor), list(Tensor))) - 可迭代的Tensor。
+        - **status** (Tensor) - 溢出检测时所需要的初始状态，仅在Ascend需要。默认值：None。

     返回：
         Tensor，布尔类型的标量Tensor。

diff --git a/mindspore/python/mindspore/amp.py b/mindspore/python/mindspore/amp.py
index 2a571a01147..ba9f73ed88a 100644
--- a/mindspore/python/mindspore/amp.py
+++ b/mindspore/python/mindspore/amp.py
@@ -63,7 +63,7 @@ def init_status():
     Returns a Tensor indicating initialized status for overflow detection.

     Note:
-        Only Ascend need status to capture overflow status, you can alse call
+        Only Ascend needs status to capture overflow status; you can also call
         this function on GPU or CPU, but the return value is useless.

     Returns:
@@ -73,7 +73,7 @@ def init_status():
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
-        >>> status = init_status()
+        >>> status = amp.init_status()
     """
     if _ascend_target():
         status = ops.NPUAllocFloatStatus()()
@@ -98,6 +98,8 @@ def all_finite(inputs, status=None):

     Args:
         inputs (Union(tuple(Tensor), list(Tensor))): a iterable Tensor.
+        status (Tensor): the status Tensor for overflow detection, only required on
+            Ascend. Default: None.

     Returns:
         Tensor, a scalar Tensor and the dtype is bool.
@@ -107,7 +109,7 @@ def all_finite(inputs, status=None):

     Examples:
         >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0]))
-        >>> output = all_finite(x)
+        >>> output = amp.all_finite(x)
     """
     if _ascend_target():
         if status is None:

diff --git a/mindspore/python/mindspore/train/loss_scale_manager.py b/mindspore/python/mindspore/train/loss_scale_manager.py
index d0330298b10..4021f7d443a 100644
--- a/mindspore/python/mindspore/train/loss_scale_manager.py
+++ b/mindspore/python/mindspore/train/loss_scale_manager.py
@@ -58,17 +58,17 @@ class FixedLossScaleManager(LossScaleManager):

     Examples:
         >>> import mindspore as ms
-        >>> from mindspore import nn
+        >>> from mindspore import amp, nn
         >>>
         >>> net = Net()
         >>> #1) Drop the parameter update if there is an overflow
-        >>> loss_scale_manager = ms.FixedLossScaleManager()
+        >>> loss_scale_manager = amp.FixedLossScaleManager()
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
         >>>
         >>> #2) Execute parameter update even if overflow occurs
         >>> loss_scale = 1024.0
-        >>> loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False)
+        >>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
         >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
     """
@@ -133,10 +133,10 @@ class DynamicLossScaleManager(LossScaleManager):

     Examples:
         >>> import mindspore as ms
-        >>> from mindspore import nn
+        >>> from mindspore import amp, nn
         >>>
         >>> net = Net()
-        >>> loss_scale_manager = ms.DynamicLossScaleManager()
+        >>> loss_scale_manager = amp.DynamicLossScaleManager()
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
     """
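
Not part of the patch: a minimal usage sketch tying together the two APIs whose docstrings are touched above, amp.init_status and amp.all_finite. The sample tensors are assumptions for illustration only; as the Note in amp.py states, the status Tensor is only meaningful on Ascend and is accepted but unused on GPU/CPU.

    import numpy as np
    from mindspore import amp, Tensor

    # Initialize the overflow-status Tensor (only meaningful on Ascend).
    status = amp.init_status()

    # A tuple of Tensors to check; np.log(-1.0) yields NaN, so not every value is finite.
    grads = (Tensor(np.array([np.log(-1.0), 1.0])), Tensor(np.array([1.0])))

    # Scalar bool Tensor: True only if every element of every input is finite.
    is_finite = amp.all_finite(grads, status)
    print(is_finite)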