forked from mindspore-Ecosystem/mindspore

commit 2ff6f0de46 (parent 2c3c1577b1)

    edit loss_scale for gpu
@@ -25,6 +25,7 @@ from ...ops import operations as P
 from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
     ControlDepend
 from ...common import dtype as mstype
+import mindspore.context as context

 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
@@ -34,6 +35,12 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))

+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)

 class DynamicLossScaleUpdateCell(Cell):
     r"""
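The registration above is what the training cell later maps over its gradient tuple with `HyperMap` and folds with `AddN` (see the construct hunk further down). A minimal plain-Python sketch of the same check, with NumPy standing in for `P.FloatStatus`; every name here is illustrative, not a MindSpore API:

    import numpy as np

    def grad_overflow(grad):
        # stands in for P.FloatStatus(): a positive flag when the tensor
        # contains NaN or +/-Inf, zero otherwise
        return float(not np.all(np.isfinite(grad)))

    def any_grad_overflow(grads):
        # mirrors hyper_map(F.partial(_grad_overflow), grads) followed by AddN:
        # sum the per-gradient flags; a sum > 0 means at least one overflow
        return sum(grad_overflow(g) for g in grads) > 0

    grads = (np.ones((2, 2), np.float16), np.array([np.inf], np.float16))
    print(any_grad_overflow(grads))  # True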
@@ -195,6 +202,12 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
         self.alloc_status = NPUAllocFloatStatus()
         self.get_status = NPUGetFloatStatus()
         self.clear_status = NPUClearFloatStatus()
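The branch keys off the configured backend. A quick way to see what the check evaluates to in a session; only `context.get_context("device_target")` is the real API used above, the rest is an illustrative snippet:

    import mindspore.context as context

    # returns the configured backend string, e.g. "GPU", "Ascend" or "CPU"
    is_gpu = context.get_context("device_target") == "GPU"
    print(is_gpu)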
@@ -222,6 +235,7 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
@@ -235,10 +249,14 @@ class TrainOneStepWithLossScaleCell(Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0:not overflow , >0:overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            # get the overflow buffer
+            self.get_status(init)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
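Whichever branch runs, `flag_sum` feeds the same decision: skip the weight update when an overflow was seen and let `DynamicLossScaleUpdateCell` adapt the scale. A rough sketch of that policy in plain Python; the factor and window values are placeholders, not the cell's actual defaults:

    def update_scale(loss_scale, overflow, good_steps, factor=2.0, window=1000):
        # illustrative dynamic loss-scale policy driven by the overflow flag
        if overflow:
            # the optimizer step is skipped; shrink the scale for the next try
            return max(loss_scale / factor, 1.0), 0
        good_steps += 1
        if good_steps >= window:
            # after a run of clean steps, probe a larger scale again
            return loss_scale * factor, 0
        return loss_scale, good_steps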
@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
                         LogicalNot, LogicalOr, MatMul, Maximum,
                         Minimum, Mul, Neg, NMSWithMask, NotEqual,
                         NPUAllocFloatStatus, NPUClearFloatStatus,
-                        NPUGetFloatStatus, Pow, RealDiv,
+                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                         Reciprocal, CumSum,
                         Sin, Sqrt, Rsqrt,
                         Square, Sub, TensorAdd, Sign, Round)
@@ -151,6 +151,10 @@ __all__ = [
     'Neg',
     'Slice',
     'DType',
+    'IsNan',
+    'IsInf',
+    'IsFinite',
+    'FloatStatus',
     'NPUAllocFloatStatus',
     'NPUGetFloatStatus',
     'NPUClearFloatStatus',
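With the four new entries exported, the GPU status primitives can be imported exactly like the existing NPU ones; the module path below follows the relative imports used in the loss_scale hunks above:

    from mindspore.ops.operations import IsNan, IsInf, IsFinite, FloatStatus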
@@ -1557,6 +1557,89 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))


+class IsNan(PrimitiveWithInfer):
+    """
+    Judges which elements are nan for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsNan"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+
+class IsInf(PrimitiveWithInfer):
+    """
+    Judges which elements are inf or -inf for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsInf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+
+class IsFinite(PrimitiveWithInfer):
+    """
+    Judges which elements are finite for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape as the input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsFinite"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+
+class FloatStatus(PrimitiveWithInfer):
+    """
+    Determines if the elements contain nan, inf or -inf.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the shape of `(1,)`.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init FloatStatus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
+
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """
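A minimal usage sketch for the four new primitives, assuming PyNative execution so they can be called directly on a tensor; the expected results follow the docstrings above, and the exact printed format depends on the MindSpore version:

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = Tensor(np.array([1.0, float("inf"), float("nan")], np.float32))

    print(P.IsNan()(x))      # elementwise: [False, False, True]
    print(P.IsInf()(x))      # elementwise: [False, True, False]
    print(P.IsFinite()(x))   # elementwise: [True, False, False]
    # FloatStatus reduces to a (1,)-shaped tensor; a positive value signals
    # that the input contained nan/inf somewhere
    print(P.FloatStatus()(x))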