!26251 fix error for mixed precision on pynative mode
Merge pull request !26251 from chujinjin/fix_mixed_precision_for_pynative
This commit is contained in:
commit
9392fc38bc
|
@ -863,9 +863,9 @@ class Cell(Cell_):
|
|||
param.set_cast_dtype(mstype.float32)
|
||||
elif mixed_type == MixedPrecisionType.FP16:
|
||||
param.set_cast_dtype(mstype.float16)
|
||||
elif hasattr(param, "set_cast_dtype"):
|
||||
# retest dtype
|
||||
param.set_cast_dtype()
|
||||
elif hasattr(param, "set_cast_dtype"):
|
||||
# retest dtype
|
||||
param.set_cast_dtype()
|
||||
return param
|
||||
|
||||
def insert_child_to_cell(self, child_name, child_cell):
|
||||
|
|
Loading…
Reference in New Issue