!23072 Clean side_effect_flag

Merge pull request !23072 from Margaret_wangrui/clean_side_effect_flag
This commit is contained in:
i-robot 2021-09-09 02:47:34 +00:00 committed by Gitee
commit 90effa5821
2 changed files with 31 additions and 16 deletions

View File

@ -4087,7 +4087,6 @@ class NPUAllocFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUAllocFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self):
return [8]
@ -4103,6 +4102,9 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
The flag is a tensor whose shape is `(8,)` and data type is `mindspore.dtype.float32`.
If the sum of the flag equals 0, no overflow has happened. If the sum of the flag is bigger than 0, an
overflow has happened.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus operator,
you need to ensure that the NPUClearFloatStatus operator and your computation have been executed.
We use Depend to ensure the execution order.
Inputs:
- **x** (Tensor) - The output tensor of `NPUAllocFloatStatus`.
@ -4120,10 +4122,17 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
``Ascend``
Examples:
>>> alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus()
>>> init = alloc_status()
>>> get_status(init)
>>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> self.get_status = ops.NPUGetFloatStatus()
>>> self.clear_status = ops.NPUClearFloatStatus()
>>> init = self.alloc_status()
>>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.]
"""
@ -4131,7 +4140,6 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUGetFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.name
@ -4151,6 +4159,9 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
Note:
The flag is in the register on the `Ascend` device. It will be reset and can not be reused again after the
`NPUClearFloatStatus` is called.
In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus
operator, you need to ensure that the NPUClearFloatStatus operator and your computation have been executed.
We use Depend to ensure the execution order.
Examples: see `NPUGetFloatStatus`.
@ -4165,12 +4176,17 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
``Ascend``
Examples:
>>> alloc_status = ops.NPUAllocFloatStatus()
>>> get_status = ops.NPUGetFloatStatus()
>>> clear_status = ops.NPUClearFloatStatus()
>>> init = alloc_status()
>>> flag = get_status(init)
>>> clear_status(init)
>>> self.alloc_status = ops.NPUAllocFloatStatus()
>>> self.get_status = ops.NPUGetFloatStatus()
>>> self.clear_status = ops.NPUClearFloatStatus()
>>> init = self.alloc_status()
>>> init = F.Depend(init, input) # Ensure clear_status after input
>>> clear_status = self.clear_status(init)
>>> input = F.Depend(input, clear_status) # Ensure your compute after clear_status
>>> output = Compute(input)
>>> init = F.Depend(init, output)
>>> flag = self.get_status(init) # Ensure get_status after your compute
>>> self.clear_status(init)
>>> print(init)
[0. 0. 0. 0. 0. 0. 0. 0.]
"""
@ -4178,7 +4194,6 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize NPUClearFloatStatus"""
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.name

View File

@ -189,9 +189,9 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()