diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 1173738100e..f79dce9777b 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -202,9 +202,13 @@ class Cell: out = self.compile_and_run(*inputs) return out self.init_parameters_data() + orign_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) _pynative_exec.new_graph(self, *inputs) + for cell in self.cells(): + orign_grad.append(cell.requires_grad) + cell.set_grad(True) else: _pynative_exec.set_grad_flag(False) if self.enable_hook: @@ -215,6 +219,8 @@ class Cell: output = output.data if self.requires_grad is True: _pynative_exec.end_graph(self, output, *inputs) + for i, cell in enumerate(self.cells()): + cell.set_grad(orign_grad[i]) self._is_run = True return output @@ -744,7 +750,7 @@ class Cell: return self def set_grad(self, mode=True): - self.add_flags_recursive(requires_grad=mode) + self.requires_grad = mode return self def set_train(self, mode=True):