!26251 fix error for mixed precision on pynative mode
Merge pull request !26251 from chujinjin/fix_mixed_precision_for_pynative
This commit is contained in:
commit
9392fc38bc
|
@ -863,9 +863,9 @@ class Cell(Cell_):
|
||||||
param.set_cast_dtype(mstype.float32)
|
param.set_cast_dtype(mstype.float32)
|
||||||
elif mixed_type == MixedPrecisionType.FP16:
|
elif mixed_type == MixedPrecisionType.FP16:
|
||||||
param.set_cast_dtype(mstype.float16)
|
param.set_cast_dtype(mstype.float16)
|
||||||
elif hasattr(param, "set_cast_dtype"):
|
elif hasattr(param, "set_cast_dtype"):
|
||||||
# retest dtype
|
# retest dtype
|
||||||
param.set_cast_dtype()
|
param.set_cast_dtype()
|
||||||
return param
|
return param
|
||||||
|
|
||||||
def insert_child_to_cell(self, child_name, child_cell):
|
def insert_child_to_cell(self, child_name, child_cell):
|
||||||
|
|
Loading…
Reference in New Issue