clean side_effect_flag

This commit is contained in:
Margaret_wangrui 2021-09-08 15:21:21 +08:00
parent dd1b8390c7
commit 34f443ca18
2 changed files with 31 additions and 16 deletions

View File

@ -4087,7 +4087,6 @@ class NPUAllocFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUAllocFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self):
return [8]
@ -4103,6 +4102,9 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
The flag is a tensor whose shape is `(8,)` and data type is `mindspore.dtype.float32`.
If the sum of the flag equals to 0, there is no overflow happened. If the sum of the flag is bigger than 0, there
is overflow happened.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus operator,
need to ensure that the NPUClearFlotStatus and your compute has been executed.
We use Depend to ensure the execution order.
Inputs:
- **x** (Tensor) - The output tensor of `NPUAllocFloatStatus`.
@ -4120,10 +4122,17 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
``Ascend``
Examples:
>>> alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus()
>>> init = alloc_status()
>>> get_status(init)
>>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> self.get_status = ops.NPUGetFloatStatus()
>>> self.clear_status = ops.NPUClearFloatStatus()
>>> init = self.alloc_status()
>>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.]
"""
@ -4131,7 +4140,6 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUGetFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.name
@ -4151,6 +4159,9 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
Note:
The flag is in the register on the `Ascend` device. It will be reset and can not be reused again after the
`NPUClearFloatStatus` is called.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus
operator, need to ensure that the NPUClearFlotStatus and your compute has been executed.
We use Depend to ensure the execution order.
Examples: see `NPUGetFloatStatus`.
@ -4165,12 +4176,17 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
``Ascend``
Examples:
>>> alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus()
>>> clear_status = ops.NPUClearFloatStatus()
>>> init = alloc_status()
>>> flag = get_status(init)
>>> clear_status(init)
>>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> self.get_status = ops.NPUGetFloatStatus()
>>> self.clear_status = ops.NPUClearFloatStatus()
>>> init = self.alloc_status()
>>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.]
"""
@ -4178,7 +4194,6 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUClearFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.name

View File

@ -189,9 +189,9 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()