!47908 add init_status for amp
Merge pull request !47908 from 吕昱峰(Nate.River)/r2.0.0-alpha
Commit: 2174ae676c
@ -1,7 +1,7 @@
mindspore.amp.all_finite
========================

.. py:function:: mindspore.amp.all_finite(inputs)
.. py:function:: mindspore.amp.all_finite(inputs, status=None)

    Check whether the inputs are all finite (no overflow).
@ -12,6 +12,7 @@ mindspore.amp.all_finite
    Parameters:
        - **inputs** (Union(tuple(Tensor), list(Tensor))) - An iterable of Tensors.
        - **status** (Tensor) - The initial status required for overflow detection; only needed on Ascend. Default: None.

    Returns:
        Tensor, a scalar Tensor of bool type.
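For reference, a short usage sketch of the updated signature (hedged; it assumes the module is imported as `from mindspore import amp`, as in the test change further below):

    >>> import numpy as np
    >>> from mindspore import amp, Tensor
    >>> status = amp.init_status()
    >>> x = (Tensor(np.array([0.5, 1.0])), Tensor(np.array([0.2])))
    >>> output = amp.all_finite(x, status)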
@ -0,0 +1,12 @@
mindspore.amp.init_status
===========================

.. py:function:: mindspore.amp.init_status()

    Initialize the variable used for overflow-status detection.

    .. note::
        This interface is only effective on the Ascend backend; the value returned when it is called on GPU or CPU has no effect.

    Returns:
        Tensor, with shape (8,).
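A quick sanity check of the returned value (a hedged sketch; per the implementation in this commit, on GPU and CPU it is simply an all-zero float32 Tensor):

    >>> from mindspore import amp
    >>> status = amp.init_status()
    >>> status.shape
    (8,)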
@ -1,20 +1,7 @@
mindspore.amp
================

Cell Management
-----------

.. mscnautosummary::
    :toctree: amp
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.DynamicLossScaleManager
    mindspore.amp.LossScaleManager
    mindspore.amp.FixedLossScaleManager
    mindspore.amp.build_train_network

Functional
Gradient Scaling
-----------

.. mscnautosummary::
@ -22,8 +9,31 @@ Cell Management
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.LossScaler
    mindspore.amp.DynamicLossScaler
    mindspore.amp.StaticLossScaler
    mindspore.amp.LossScaler
    mindspore.amp.LossScaleManager
    mindspore.amp.DynamicLossScaleManager
    mindspore.amp.FixedLossScaleManager

Dtype Autocast
----------------

.. mscnautosummary::
    :toctree: amp
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.auto_mixed_precision
    mindspore.amp.build_train_network

Overflow Detection
-----------

.. mscnautosummary::
    :toctree: amp
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.init_status
    mindspore.amp.all_finite
@ -1,7 +1,7 @@
mindspore.amp
================

Cell Management
Loss Scale
----------------

.. autosummary::
@ -9,12 +9,14 @@ Cell Management
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.DynamicLossScaleManager
    mindspore.amp.LossScaler
    mindspore.amp.DynamicLossScaler
    mindspore.amp.StaticLossScaler
    mindspore.amp.LossScaleManager
    mindspore.amp.DynamicLossScaleManager
    mindspore.amp.FixedLossScaleManager
    mindspore.amp.build_train_network

Functional Paradigm
Dtype Autocast
--------------------

.. autosummary::
@ -22,8 +24,16 @@ Functional Paradigm
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.DynamicLossScaler
    mindspore.amp.StaticLossScaler
    mindspore.amp.LossScaler
    mindspore.amp.auto_mixed_precision
    mindspore.amp.build_train_network

Overflow Detection
--------------------

.. autosummary::
    :toctree: amp
    :nosignatures:
    :template: classtemplate.rst

    mindspore.amp.init_status
    mindspore.amp.all_finite
@ -21,31 +21,28 @@ from ._checkparam import Validator as validator
from .common import dtype as mstype
from . import context
from . import ops
from .ops import constexpr
from .common.api import jit_class
from .common.parameter import Parameter
from .common.tensor import Tensor
from .train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager, FixedLossScaleManager
from .train.amp import build_train_network, auto_mixed_precision

_ascend_target = context.get_context("device_target") == "Ascend"
_gpu_target = context.get_context("device_target") == "GPU"

_gpu_float_status = ops.FloatStatus()

_npu_alloc_float_status = ops.NPUAllocFloatStatus()
_npu_clear_float_status = ops.NPUClearFloatStatus()
_npu_get_float_status = ops.NPUGetFloatStatus()

if _ascend_target:
    _status = _npu_alloc_float_status()
    _ = _npu_clear_float_status(_status)
else:
    _status = None

_hypermap = ops.HyperMap()
_partial = ops.Partial()


@constexpr
def _ascend_target():
    return context.get_context("device_target") == "Ascend"


@constexpr
def _gpu_target():
    return context.get_context("device_target") == "GPU"


def _grad_unscale(scale, grad):
    return grad * ops.Reciprocal()(scale).astype(grad.dtype)
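The hunk above replaces the module-level device flags and the pre-allocated `_status` with `@constexpr` helpers, so the device-target check is resolved at graph-compile time instead of at import time. A minimal sketch of the same pattern against the public API (the helper name `_on_ascend` is illustrative, not part of the commit):

    from mindspore import context
    from mindspore.ops import constexpr

    @constexpr
    def _on_ascend():
        # Evaluated during graph compilation, so the branch folds into a constant.
        return context.get_context("device_target") == "Ascend"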
@ -55,13 +52,40 @@ def _grad_scale(scale, grad):
def _is_finite(inputs):
    if _gpu_target:
        return _gpu_float_status(inputs)[0] == 0
    if _gpu_target():
        return ops.FloatStatus()(inputs)[0] == 0
    status = ops.isfinite(inputs)
    return status.all()


def all_finite(inputs):
def init_status():
    r"""
    Returns a Tensor indicating initialized status for overflow detection.

    Note:
        Only Ascend needs a status Tensor to capture overflow; this function can also be
        called on GPU or CPU, but the return value there has no effect.

    Returns:
        Tensor, has the shape of `(8,)`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> status = amp.init_status()
    """
    if _ascend_target():
        status = ops.NPUAllocFloatStatus()()
        clear_status = ops.NPUClearFloatStatus()(status)
        status = ops.depend(status, clear_status)
    else:
        status = Tensor([0, 0, 0, 0, 0, 0, 0, 0], mstype.float32)

    return status


def all_finite(inputs, status=None):
    r"""
    Returns a scalar Tensor indicating whether the inputs are finite.
@ -74,6 +98,8 @@ def all_finite(inputs):
    Args:
        inputs (Union(tuple(Tensor), list(Tensor))): an iterable of Tensors.
        status (Tensor): the status Tensor for overflow detection, only required on
            Ascend. Default: None.

    Returns:
        Tensor, a scalar Tensor with dtype bool.
@ -83,14 +109,16 @@ def all_finite(inputs):
    Examples:
        >>> x = (Tensor(np.array([np.log(-1), 1, np.log(0)])), Tensor(np.array([1.0])))
        >>> output = all_finite(x)
        >>> output = amp.all_finite(x)
    """
    if _ascend_target:
        status = ops.depend(_status, inputs)
        get_status = _npu_get_float_status(status)
    if _ascend_target():
        if status is None:
            raise ValueError("The status must be initialized on Ascend, but got 'None'.")
        status = ops.depend(status, inputs)
        get_status = ops.NPUGetFloatStatus()(status)
        status = ops.depend(status, get_status)
        status_finite = status.sum() == 0
        _ = _npu_clear_float_status(status)
        _ = ops.NPUClearFloatStatus()(status)
        return status_finite
    outputs = _hypermap(_partial(_is_finite), inputs)
    return ops.stack(outputs).all()
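Putting the two new pieces together: on Ascend the caller now allocates a status Tensor with `init_status()` and threads it through `all_finite()`, which chains reads of the NPU float status via `ops.depend`. A minimal sketch assuming `from mindspore import amp`, mirroring the updated test below:

    import numpy as np
    import mindspore
    from mindspore import amp, Tensor

    loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)
    grads = (Tensor(np.array([0.5, 1.0]), mindspore.float16),
             Tensor(np.array([0.2]), mindspore.float16))

    status = amp.init_status()                             # allocate and clear the overflow status
    unscaled_grads = loss_scaler.unscale(grads)
    grads_finite = amp.all_finite(unscaled_grads, status)  # scalar bool Tensor
    loss_scaler.adjust(grads_finite)                       # grow or shrink scale_value accordingly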
@ -299,5 +327,5 @@ class DynamicLossScaler(LossScaler):
__all__ = [
    "DynamicLossScaleManager", "LossScaleManager", "FixedLossScaleManager",
    "build_train_network", "DynamicLossScaler", "StaticLossScaler", "LossScaler",
    "auto_mixed_precision", "all_finite"
    "auto_mixed_precision", "init_status", "all_finite"
]
@ -60,20 +60,22 @@ def test_dynamic_loss_scaler(mode):
    Expectation: the `scale_value` can be adjusted correctly.
    """
    context.set_context(mode=mode)
    status = amp.init_status()
    loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)

    grads = (Tensor(np.array([0.5, 1.0]), mindspore.float16),
             Tensor(np.array([0.2]), mindspore.float16))
    unscaled_grads = loss_scaler.unscale(grads)
    grads_finite = amp.all_finite(unscaled_grads)
    grads_finite = amp.all_finite(unscaled_grads, status)
    loss_scaler.counter = Parameter(Tensor(49, dtype=mstype.int32))
    loss_scaler.adjust(grads_finite)
    assert loss_scaler.scale_value.asnumpy() == np.array(2048.)

    status = amp.init_status()
    grads = (Tensor(np.array([2., 1.0]), mindspore.float16),
             Tensor(np.array([0.2]), mindspore.float16))
    unscaled_grads = loss_scaler.unscale(grads)
    grads_finite = amp.all_finite(unscaled_grads)
    grads_finite = amp.all_finite(unscaled_grads, status)
    loss_scaler.scale_value = Parameter(Tensor(2**10, dtype=mstype.float32))
    loss_scaler.adjust(grads_finite)
    assert loss_scaler.scale_value.asnumpy() == np.array(1024.)