forked from mindspore-Ecosystem/mindspore
!23072 Clean side_effect_flag
Merge pull request !23072 from Margaret_wangrui/clean_side_effect_flag
commit 90effa5821
@@ -4087,7 +4087,6 @@ class NPUAllocFloatStatus(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize NPUAllocFloatStatus"""
-        self.add_prim_attr("_side_effect_flag", True)

     def infer_shape(self):
         return [8]
@@ -4103,6 +4102,9 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
     The flag is a tensor whose shape is `(8,)` and data type is `mindspore.dtype.float32`.
     If the sum of the flag equals 0, no overflow has happened. If the sum of the flag is bigger than 0, an
     overflow has happened.
+    In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus operator,
+    you need to ensure that NPUClearFloatStatus and your compute have been executed.
+    We use Depend to ensure the execution order.

     Inputs:
         - **x** (Tensor) - The output tensor of `NPUAllocFloatStatus`.
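The semantics added above reduce to a single test: an overflow happened iff the sum of the eight flag entries is positive. A minimal, device-independent sketch of that check (the status tensors below are hypothetical stand-ins for the real `(8,)` register contents):

import numpy as np
from mindspore import Tensor

# Hypothetical status tensors standing in for the (8,) flag:
# all zeros means no overflow; any positive entry means overflow.
clean = Tensor(np.zeros(8, np.float32))
dirty = Tensor(np.array([0., 0., 1., 0., 0., 0., 0., 0.], np.float32))

print(clean.asnumpy().sum() > 0)  # False: no overflow happened
print(dirty.asnumpy().sum() > 0)  # True: an overflow happened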
@@ -4120,10 +4122,17 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         ``Ascend``

     Examples:
-        >>> alloc_status = ops.NPUAllocFloatStatus()
-        >>> get_status = ops.NPUGetFloatStatus()
-        >>> init = alloc_status()
-        >>> get_status(init)
+        >>> self.alloc_status = ops.NPUAllocFloatStatus()
+        >>> self.get_status = ops.NPUGetFloatStatus()
+        >>> self.clear_status = ops.NPUClearFloatStatus()
+        >>> init = self.alloc_status()
+        >>> init = F.Depend(init, input)  # Ensure clear_status after input
+        >>> clear_status = self.clear_status(init)
+        >>> input = F.Depend(input, clear_status)  # Ensure your compute after clear_status
+        >>> output = Compute(input)
+        >>> init = F.Depend(init, output)
+        >>> flag = self.get_status(init)  # Ensure get_status after your compute
+        >>> self.clear_status(init)
         >>> print(init)
         [0. 0. 0. 0. 0. 0. 0. 0.]
     """
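The doctest above is illustrative rather than runnable as-is: `input`, `Compute`, and the `self.` attributes assume a surrounding Cell. As a runnable sketch of the same ordering pattern, assuming an Ascend device, `F.depend` from `mindspore.ops.functional`, and `ops.Square` as a stand-in for the monitored compute:

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import functional as F

class OverflowCheckNet(nn.Cell):
    """Sketch: clear the status, run the compute, then read the status."""
    def __init__(self):
        super(OverflowCheckNet, self).__init__()
        self.alloc_status = ops.NPUAllocFloatStatus()
        self.get_status = ops.NPUGetFloatStatus()
        self.clear_status = ops.NPUClearFloatStatus()
        self.reduce_sum = ops.ReduceSum(keep_dims=False)
        self.square = ops.Square()  # stand-in for the user's compute

    def construct(self, x):
        init = self.alloc_status()
        init = F.depend(init, x)          # ensure clear_status runs after x is ready
        clear = self.clear_status(init)
        x = F.depend(x, clear)            # ensure the compute runs after clear_status
        out = self.square(x)
        init = F.depend(init, out)
        flag = self.get_status(init)      # ensure get_status runs after the compute
        init = F.depend(init, flag)
        flag_sum = self.reduce_sum(init, (0,))
        return out, flag_sum              # flag_sum > 0 means an overflow happened

net = OverflowCheckNet()
out, flag_sum = net(Tensor(np.array([3.0e38], np.float32)))  # square overflows float32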
@@ -4131,7 +4140,6 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize NPUGetFloatStatus"""
-        self.add_prim_attr("_side_effect_flag", True)

     def infer_shape(self, x_shape):
         cls_name = self.name
@@ -4151,6 +4159,9 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
     Note:
         The flag is in the register on the `Ascend` device. It will be reset and cannot be reused again after the
         `NPUClearFloatStatus` is called.
+        In addition, there are strict sequencing requirements for use, i.e., before using the NPUGetFloatStatus
+        operator, you need to ensure that NPUClearFloatStatus and your compute have been executed.
+        We use Depend to ensure the execution order.

     Examples: see `NPUGetFloatStatus`.

@@ -4165,12 +4176,17 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         ``Ascend``

     Examples:
-        >>> alloc_status = ops.NPUAllocFloatStatus()
-        >>> get_status = ops.NPUGetFloatStatus()
-        >>> clear_status = ops.NPUClearFloatStatus()
-        >>> init = alloc_status()
-        >>> flag = get_status(init)
-        >>> clear_status(init)
+        >>> self.alloc_status = ops.NPUAllocFloatStatus()
+        >>> self.get_status = ops.NPUGetFloatStatus()
+        >>> self.clear_status = ops.NPUClearFloatStatus()
+        >>> init = self.alloc_status()
+        >>> init = F.Depend(init, input)  # Ensure clear_status after input
+        >>> clear_status = self.clear_status(init)
+        >>> input = F.Depend(input, clear_status)  # Ensure your compute after clear_status
+        >>> output = Compute(input)
+        >>> init = F.Depend(init, output)
+        >>> flag = self.get_status(init)  # Ensure get_status after your compute
+        >>> self.clear_status(init)
         >>> print(init)
         [0. 0. 0. 0. 0. 0. 0. 0.]
     """
@@ -4178,7 +4194,6 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize NPUClearFloatStatus"""
-        self.add_prim_attr("_side_effect_flag", True)

     def infer_shape(self, x_shape):
         cls_name = self.name

@@ -189,9 +189,9 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
         self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
-        self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
-        self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
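This hunk drops the manual `_side_effect_flag` overrides, mirroring the removal of the hard-coded attribute from the operator definitions above. For context, the `reduce_sum`, `base`, and `less_equal` members kept in this constructor typically combine into the overflow test roughly as below; this is a sketch under that assumption, not the actual PanguAlpha source:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

# Overflow test: an overflow happened iff base (1.0) <= sum(status flag).
reduce_sum = P.ReduceSum(keep_dims=False)
less_equal = P.LessEqual()
base = Tensor(1, mstype.float32)

status = Tensor(np.array([0., 1., 0., 0., 0., 0., 0., 0.], np.float32))  # stand-in flag
flag_sum = reduce_sum(status, (0,))
overflow = less_equal(base, flag_sum)
print(overflow)  # True: this stand-in flag signals an overflow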