From abf019d95f41678148586e3a766565fd17fafa92 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Sat, 13 Nov 2021 16:57:19 +0800 Subject: [PATCH] fix mixed precision for pynative --- mindspore/nn/cell.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 8b047577bad..a280cc7c0a4 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -863,9 +863,9 @@ class Cell(Cell_): param.set_cast_dtype(mstype.float32) elif mixed_type == MixedPrecisionType.FP16: param.set_cast_dtype(mstype.float16) - elif hasattr(param, "set_cast_dtype"): - # retest dtype - param.set_cast_dtype() + elif hasattr(param, "set_cast_dtype"): + # retest dtype + param.set_cast_dtype() return param def insert_child_to_cell(self, child_name, child_cell):