forked from mindspore-Ecosystem/mindspore
!1907 fix pynative param bug
Merge pull request !1907 from flywind/fix_pynative_bug
This commit is contained in:
commit
095e41eff3
|
@ -202,9 +202,13 @@ class Cell:
|
|||
out = self.compile_and_run(*inputs)
|
||||
return out
|
||||
self.init_parameters_data()
|
||||
orign_grad = []
|
||||
if self.requires_grad is True:
|
||||
_pynative_exec.set_grad_flag(True)
|
||||
_pynative_exec.new_graph(self, *inputs)
|
||||
for cell in self.cells():
|
||||
orign_grad.append(cell.requires_grad)
|
||||
cell.set_grad(True)
|
||||
else:
|
||||
_pynative_exec.set_grad_flag(False)
|
||||
if self.enable_hook:
|
||||
|
@ -215,6 +219,8 @@ class Cell:
|
|||
output = output.data
|
||||
if self.requires_grad is True:
|
||||
_pynative_exec.end_graph(self, output, *inputs)
|
||||
for i, cell in enumerate(self.cells()):
|
||||
cell.set_grad(orign_grad[i])
|
||||
self._is_run = True
|
||||
return output
|
||||
|
||||
|
@ -744,7 +750,7 @@ class Cell:
|
|||
return self
|
||||
|
||||
def set_grad(self, mode=True):
|
||||
self.add_flags_recursive(requires_grad=mode)
|
||||
self.requires_grad = mode
|
||||
return self
|
||||
|
||||
def set_train(self, mode=True):
|
||||
|
|
Loading…
Reference in New Issue