recovery-already-run

This commit is contained in:
lvliang 2020-12-22 21:55:43 +08:00
parent 61ed05f133
commit 04392f3120
3 changed files with 3 additions and 3 deletions

View File

@ -361,6 +361,7 @@ class Cell(Cell_):
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
return output
def _add_attr(self, name, value):

View File

@ -38,7 +38,6 @@ random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)
grad_by_list = CP.GradOperation(get_by_list=True)
@ -404,10 +403,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()

View File

@ -403,10 +403,10 @@ def test_pynative_resnet50():
step = step + 1
if step > max_step:
break
start_time = time.time()
input_data = element["image"]
input_label = element["label"]
loss_output = net_with_criterion(input_data, input_label)
start_time = time.time()
grads = train_network(input_data, input_label)
optimizer(grads)
end_time = time.time()