forked from mindspore-Ecosystem/mindspore
!161 Edit loss_scale to fit GPU
Merge pull request !161 from VectorSL/master
Commit 0d838c7c9b
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Loss scale cell for loss scale training."""
+import mindspore.context as context
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
@@ -34,6 +35,13 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""
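The `_grad_overflow` graph added here is later applied to every gradient through `C.HyperMap` (see the construct hunk below). The following is a toy sketch of that register-then-map pattern, not part of the commit: the `_double`/`DoubleEach` names are illustrative only, and it assumes a working MindSpore install with whatever device target is configured globally.

# Toy sketch (not from this commit) of the MultitypeFuncGraph + HyperMap
# pattern behind _grad_overflow: register one implementation per input type,
# then map it over a whole tuple of tensors inside a cell.
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import composite as C

_double = C.MultitypeFuncGraph("double")

@_double.register("Tensor")
def _double_tensor(x):
    # runs once for each tensor that HyperMap hands in
    return x * 2

class DoubleEach(nn.Cell):
    def __init__(self):
        super(DoubleEach, self).__init__()
        self.hyper_map = C.HyperMap()

    def construct(self, a, b):
        # dispatched per element of the tuple, the same way the loss-scale
        # cell maps _grad_overflow over `grads`
        return self.hyper_map(_double, (a, b))

out = DoubleEach()(Tensor(np.array([1.0, 2.0]), mstype.float32),
                   Tensor(np.array([3.0]), mstype.float32))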
@@ -197,9 +205,15 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        self.alloc_status = NPUAllocFloatStatus()
-        self.get_status = NPUGetFloatStatus()
-        self.clear_status = NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
+            self.alloc_status = NPUAllocFloatStatus()
+            self.get_status = NPUGetFloatStatus()
+            self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
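The GPU/Ascend split above keys off the global device target, read through the `mindspore.context` import added at the top of the file. A minimal sketch of how that value is set and read back (not from the diff):

# The branch above reads whatever device target was configured globally
# before the cell was constructed.
import mindspore.context as context

context.set_context(device_target="GPU")      # or "Ascend" / "CPU"
print(context.get_context("device_target"))   # -> "GPU"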
@@ -224,10 +238,12 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
-        self.clear_status(init)
+        init = False
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
+            self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -238,9 +254,13 @@ class TrainOneStepWithLossScaleCell(Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0:not overflow , >0:overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            self.get_status(init)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
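On GPU there is no NPU float-status register to read, so the new branch builds one overflow flag per gradient with `FloatStatus` and sums the flags with `AddN`; a non-zero sum means some gradient contained nan or inf. A standalone sketch of just that check follows; it is not from the commit, assumes a GPU target and PyNative mode so the primitives run eagerly, and uses `g1`/`g2` as stand-ins for gradient tensors.

# Standalone sketch of the GPU overflow test: one FloatStatus flag per gradient,
# summed with AddN; a result > 0 signals overflow. g1/g2 are stand-in gradients.
import numpy as np
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

float_status = P.FloatStatus()
addn = P.AddN()

g1 = Tensor(np.array([0.1, -0.2]), mstype.float32)
g2 = Tensor(np.array([float("nan"), 1.0]), mstype.float32)

flag_sum = addn([float_status(g1), float_status(g2)])
print(flag_sum)   # > 0 here, because g2 contains nan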
@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
-                       NPUGetFloatStatus, Pow, RealDiv,
+                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt,
                        Square, Sub, TensorAdd, Sign, Round)
@@ -154,6 +154,10 @@ __all__ = [
     'NPUAllocFloatStatus',
     'NPUGetFloatStatus',
     'NPUClearFloatStatus',
+    'IsNan',
+    'IsFinite',
+    'IsInf',
+    'FloatStatus',
     'Reciprocal',
     'SmoothL1Loss',
     'ReduceAll',
@@ -1541,6 +1541,94 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
 
+class IsNan(PrimitiveWithInfer):
+    """
+    Judges which elements of the input are NaN, element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsNan"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsInf(PrimitiveWithInfer):
+    """
+    Judges which elements of the input are inf or -inf, element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsInf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsFinite(PrimitiveWithInfer):
+    """
+    Judges which elements of the input are finite, element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsFinite"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class FloatStatus(PrimitiveWithInfer):
+    """
+    Determines whether the elements contain NaN, inf, or -inf: `0` for normal, `1` for overflow.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the shape of `(1,)`, and the same dtype as the input: `mindspore.dtype.float32` or
+        `mindspore.dtype.float16`.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init FloatStatus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
 
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """
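A minimal usage sketch for the three new element-wise predicates (not part of the diff): each returns a bool mask with the input's shape. It assumes PyNative mode on a GPU target, the backend this commit adds the operators for.

# Usage sketch for the new element-wise checks; assumes PyNative mode on GPU.
import numpy as np
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.array([1.0, float("nan"), float("inf")]), mstype.float32)
print(P.IsNan()(x))      # [False  True False]
print(P.IsInf()(x))      # [False False  True]
print(P.IsFinite()(x))   # [ True False False]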
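`FloatStatus` differs from the three predicates: instead of an element-wise mask it collapses the whole tensor into a single `(1,)`-shaped flag, which is exactly what the GPU loss-scale path sums across gradients. Continuing the sketch above, under the same assumptions:

# Continuing the previous sketch: FloatStatus yields one (1,)-shaped flag for
# the whole tensor rather than an element-wise mask.
status = P.FloatStatus()(x)
print(status)   # non-zero, because x contains nan and inf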