fix eval in amp

This commit is contained in:
Wei Luning 2020-07-30 10:26:39 +08:00
parent 7f3926429b
commit ca4b2f6c0b
1 changed files with 1 additions and 1 deletions

View File

@ -174,7 +174,7 @@ class Model:
else: else:
if self._loss_fn is None: if self._loss_fn is None:
raise ValueError("loss_fn can not be None.") raise ValueError("loss_fn can not be None.")
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O0", "O3"]) self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3"])
self._eval_indexes = [0, 1, 2] self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):