forked from mindspore-Ecosystem/mindspore
edit loss_scale for gpu
This commit is contained in:
parent
2c3c1577b1
commit
2ff6f0de46
|
@ -25,6 +25,7 @@ from ...ops import operations as P
|
|||
from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
|
||||
ControlDepend
|
||||
from ...common import dtype as mstype
|
||||
import mindspore.context as context
|
||||
|
||||
_grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
@ -34,6 +35,12 @@ reciprocal = P.Reciprocal()
|
|||
def tensor_grad_scale(scale, grad):
|
||||
return grad * F.cast(reciprocal(scale), F.dtype(grad))
|
||||
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
class DynamicLossScaleUpdateCell(Cell):
|
||||
r"""
|
||||
|
@ -195,6 +202,12 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.hyper_map = C.HyperMap()
|
||||
if context.get_context("device_target") == "GPU":
|
||||
self.gpu_target = True
|
||||
self.float_status = P.FloatStatus()
|
||||
self.addn = P.AddN()
|
||||
else:
|
||||
self.gpu_target = False
|
||||
self.alloc_status = NPUAllocFloatStatus()
|
||||
self.get_status = NPUGetFloatStatus()
|
||||
self.clear_status = NPUClearFloatStatus()
|
||||
|
@ -222,6 +235,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
def construct(self, data, label, sens=None):
|
||||
weights = self.weights
|
||||
loss = self.network(data, label)
|
||||
if not self.gpu_target:
|
||||
# init overflow buffer
|
||||
init = self.alloc_status()
|
||||
# clear overflow buffer
|
||||
|
@ -235,10 +249,14 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
if not self.gpu_target:
|
||||
# get the overflow buffer
|
||||
self.get_status(init)
|
||||
# sum overflow buffer elements, 0:not overflow , >0:overflow
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
|
|
|
@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
|
|||
LogicalNot, LogicalOr, MatMul, Maximum,
|
||||
Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
||||
NPUAllocFloatStatus, NPUClearFloatStatus,
|
||||
NPUGetFloatStatus, Pow, RealDiv,
|
||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||
Reciprocal, CumSum,
|
||||
Sin, Sqrt, Rsqrt,
|
||||
Square, Sub, TensorAdd, Sign, Round)
|
||||
|
@ -151,6 +151,10 @@ __all__ = [
|
|||
'Neg',
|
||||
'Slice',
|
||||
'DType',
|
||||
'IsNan',
|
||||
'IsInf',
|
||||
'IsFinite',
|
||||
'FloatStatus',
|
||||
'NPUAllocFloatStatus',
|
||||
'NPUGetFloatStatus',
|
||||
'NPUClearFloatStatus',
|
||||
|
|
|
@ -1557,6 +1557,89 @@ class LogicalOr(_LogicBinaryOp):
|
|||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
|
||||
|
||||
class IsNan(PrimitiveWithInfer):
|
||||
"""
|
||||
Judging which elements are nan for each position
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape of input.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init IsNan"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
return mstype.bool_
|
||||
|
||||
class IsInf(PrimitiveWithInfer):
|
||||
"""
|
||||
Judging which elements are inf or -inf for each position
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape of input.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init IsInf"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
return mstype.bool_
|
||||
|
||||
class IsFinite(PrimitiveWithInfer):
|
||||
"""
|
||||
Judging which elements are finite for each position
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape of input.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init IsFinite"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
return mstype.bool_
|
||||
|
||||
class FloatStatus(PrimitiveWithInfer):
|
||||
"""
|
||||
Determine if the elements contains nan, inf or -inf
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the shape of `(1,)`.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init FloatStatus"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
return x_dtype
|
||||
|
||||
class NPUAllocFloatStatus(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue