diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 1ce3179273c..6a1f15a402c 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Loss scale cell for loss scale training."""
+import mindspore.context as context
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
@@ -34,6 +35,13 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""
@@ -197,9 +205,15 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        self.alloc_status = NPUAllocFloatStatus()
-        self.get_status = NPUGetFloatStatus()
-        self.clear_status = NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
+            self.alloc_status = NPUAllocFloatStatus()
+            self.get_status = NPUGetFloatStatus()
+            self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
@@ -224,10 +238,12 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
-        self.clear_status(init)
+        init = False
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
+            self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -238,9 +254,13 @@ class TrainOneStepWithLossScaleCell(Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0:not overflow , >0:overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            self.get_status(init)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 846be05c4d7..89a5ea02495 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
-                       NPUGetFloatStatus, Pow, RealDiv,
+                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt,
                        Square, Sub, TensorAdd, Sign, Round)
@@ -154,6 +154,10 @@ __all__ = [
     'NPUAllocFloatStatus',
     'NPUGetFloatStatus',
     'NPUClearFloatStatus',
+    'IsNan',
+    'IsFinite',
+    'IsInf',
+    'FloatStatus',
     'Reciprocal',
     'SmoothL1Loss',
     'ReduceAll',
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 175b72560f6..127d3c513c1 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -1541,6 +1541,94 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
 
+class IsNan(PrimitiveWithInfer):
+    """
+    Judges which elements are NaN for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsNan"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsInf(PrimitiveWithInfer):
+    """
+    Judges which elements are inf or -inf for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsInf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsFinite(PrimitiveWithInfer):
+    """
+    Judges which elements are finite for each position.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the same shape as the input, and the dtype is bool.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsFinite"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class FloatStatus(PrimitiveWithInfer):
+    """
+    Determines whether the elements contain NaN, inf or -inf. `0` for normal, `1` for overflow.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the shape of `(1,)`, and the same dtype as the input (`mindspore.dtype.float32` or
+        `mindspore.dtype.float16`).
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init FloatStatus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
 
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """
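
Usage note (not part of the patch): a minimal sketch of how the newly exported float-status primitives might be exercised directly. It assumes a MindSpore build that includes this change and a machine with device_target="GPU"; the expected values in the comments follow the docstrings added above rather than a verified run.

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# One finite value, one inf, one nan.
x = Tensor(np.array([1.0, float("inf"), float("nan")], np.float32))

print(P.IsNan()(x))     # expected [False, False, True]
print(P.IsInf()(x))     # expected [False, True, False]
print(P.IsFinite()(x))  # expected [True, False, False]

# FloatStatus collapses the whole tensor into a single flag of shape (1,):
# 0 means every element is finite, a non-zero value signals nan/inf (overflow).
print(P.FloatStatus()(x))

This per-tensor flag is what the GPU branch of TrainOneStepWithLossScaleCell relies on: it maps FloatStatus over every gradient and sums the flags with AddN, instead of reading the NPU float-status buffer used on the non-GPU path.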