forked from mindspore-Ecosystem/mindspore
recovery-already-run
This commit is contained in:
parent
61ed05f133
commit
04392f3120
|
@ -361,6 +361,7 @@ class Cell(Cell_):
|
|||
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
|
||||
for i, cell in enumerate(self.cells()):
|
||||
cell.set_grad(origin_grad[i])
|
||||
self._already_run = True
|
||||
return output
|
||||
|
||||
def _add_attr(self, name, value):
|
||||
|
|
|
@ -38,7 +38,6 @@ random.seed(1)
|
|||
np.random.seed(1)
|
||||
ds.config.set_seed(1)
|
||||
|
||||
|
||||
grad_by_list = CP.GradOperation(get_by_list=True)
|
||||
|
||||
|
||||
|
@ -404,10 +403,10 @@ def test_pynative_resnet50():
|
|||
step = step + 1
|
||||
if step > max_step:
|
||||
break
|
||||
start_time = time.time()
|
||||
input_data = element["image"]
|
||||
input_label = element["label"]
|
||||
loss_output = net_with_criterion(input_data, input_label)
|
||||
start_time = time.time()
|
||||
grads = train_network(input_data, input_label)
|
||||
optimizer(grads)
|
||||
end_time = time.time()
|
||||
|
|
|
@ -403,10 +403,10 @@ def test_pynative_resnet50():
|
|||
step = step + 1
|
||||
if step > max_step:
|
||||
break
|
||||
start_time = time.time()
|
||||
input_data = element["image"]
|
||||
input_label = element["label"]
|
||||
loss_output = net_with_criterion(input_data, input_label)
|
||||
start_time = time.time()
|
||||
grads = train_network(input_data, input_label)
|
||||
optimizer(grads)
|
||||
end_time = time.time()
|
||||
|
|
Loading…
Reference in New Issue