!23072 Clean side_effect_flag

Merge pull request !23072 from Margaret_wangrui/clean_side_effect_flag
This commit is contained in:
i-robot 2021-09-09 02:47:34 +00:00 committed by Gitee
commit 90effa5821
2 changed files with 31 additions and 16 deletions

View File

@ -4087,7 +4087,6 @@ class NPUAllocFloatStatus(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize NPUAllocFloatStatus""" """Initialize NPUAllocFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self): def infer_shape(self):
return [8] return [8]
@ -4103,6 +4102,9 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
The flag is a tensor whose shape is `(8,)` and data type is `mindspore.dtype.float32`. The flag is a tensor whose shape is `(8,)` and data type is `mindspore.dtype.float32`.
If the sum of the flag equals to 0, there is no overflow happened. If the sum of the flag is bigger than 0, there If the sum of the flag equals to 0, there is no overflow happened. If the sum of the flag is bigger than 0, there
is overflow happened. is overflow happened.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus operator,
you need to ensure that NPUClearFloatStatus and your computation have been executed.
We use Depend to ensure the execution order.
Inputs: Inputs:
- **x** (Tensor) - The output tensor of `NPUAllocFloatStatus`. - **x** (Tensor) - The output tensor of `NPUAllocFloatStatus`.
@ -4120,10 +4122,17 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
``Ascend`` ``Ascend``
Examples: Examples:
>>> alloc_status = ops.NPUAllocFloatStatus() >>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus() >>> self.get_status = ops.NPUGetFloatStatus()
>>> init = alloc_status() >>> self.clear_status = ops.NPUClearFloatStatus()
>>> get_status(init) >>> init = self.alloc_status()
>>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init) >>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0.]
""" """
@ -4131,7 +4140,6 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize NPUGetFloatStatus""" """Initialize NPUGetFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.name cls_name = self.name
@ -4151,6 +4159,9 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
Note: Note:
The flag is in the register on the `Ascend` device. It will be reset and can not be reused again after the The flag is in the register on the `Ascend` device. It will be reset and can not be reused again after the
`NPUClearFloatStatus` is called. `NPUClearFloatStatus` is called.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus
operator, you need to ensure that NPUClearFloatStatus and your computation have been executed.
We use Depend to ensure the execution order.
Examples: see `NPUGetFloatStatus`. Examples: see `NPUGetFloatStatus`.
@ -4165,12 +4176,17 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
``Ascend`` ``Ascend``
Examples: Examples:
>>> alloc_status = ops.NPUAllocFloatStatus() >>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus() >>> self.get_status = ops.NPUGetFloatStatus()
>>> clear_status = ops.NPUClearFloatStatus() >>> self.clear_status = ops.NPUClearFloatStatus()
>>> init = alloc_status() >>> init = self.alloc_status()
>>> flag = get_status(init) >>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status(init) >>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init) >>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0.]
""" """
@ -4178,7 +4194,6 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize NPUClearFloatStatus""" """Initialize NPUClearFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.name cls_name = self.name

View File

@ -189,9 +189,9 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast() self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False) self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False) self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False) self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32) self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual() self.less_equal = P.LessEqual()