diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 8b047577bad..a280cc7c0a4 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -863,9 +863,9 @@ class Cell(Cell_): param.set_cast_dtype(mstype.float32) elif mixed_type == MixedPrecisionType.FP16: param.set_cast_dtype(mstype.float16) - elif hasattr(param, "set_cast_dtype"): - # retest dtype - param.set_cast_dtype() + elif hasattr(param, "set_cast_dtype"): + # retest dtype + param.set_cast_dtype() return param def insert_child_to_cell(self, child_name, child_cell):