!26251 fix error for mixed precision on pynative mode

Merge pull request !26251 from chujinjin/fix_mixed_precision_for_pynative
This commit is contained in:
i-robot 2021-11-15 03:26:17 +00:00 committed by Gitee
commit 9392fc38bc
1 changed files with 3 additions and 3 deletions

View File

@ -863,9 +863,9 @@ class Cell(Cell_):
param.set_cast_dtype(mstype.float32) param.set_cast_dtype(mstype.float32)
elif mixed_type == MixedPrecisionType.FP16: elif mixed_type == MixedPrecisionType.FP16:
param.set_cast_dtype(mstype.float16) param.set_cast_dtype(mstype.float16)
elif hasattr(param, "set_cast_dtype"): elif hasattr(param, "set_cast_dtype"):
# retest dtype # retest dtype
param.set_cast_dtype() param.set_cast_dtype()
return param return param
def insert_child_to_cell(self, child_name, child_cell): def insert_child_to_cell(self, child_name, child_cell):