revert remove_redundant_depend
parent 111d1a9a61
commit fbe4c5278a
@@ -19,6 +19,7 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
+from mindspore.ops import functional as F
 from mindspore import ParameterTuple
 from mindspore.train.callback import Callback
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -146,5 +147,5 @@ class TrainOneStepCell(nn.Cell):
         grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
         if self.reduce_flag:
             grads = self.grad_reducer(grads)
-        self.optimizer(grads)
-        return loss
+
+        return F.depend(loss, self.optimizer(grads))

@@ -19,6 +19,7 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
+from mindspore.ops import functional as F
 from mindspore import ParameterTuple
 from mindspore.train.callback import Callback
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -145,5 +146,5 @@ class TrainOneStepCell(nn.Cell):
         grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
         if self.reduce_flag:
             grads = self.grad_reducer(grads)
-        self.optimizer(grads)
-        return loss
+
+        return F.depend(loss, self.optimizer(grads))
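Both files revert the same simplification: instead of calling self.optimizer(grads) and then returning loss as two independent statements, the update is chained back through F.depend. Under MindSpore graph compilation the optimizer call is a side-effect node with no data path to the return value, so without the explicit dependency the compiler is free to prune or reorder the parameter update; the depend is therefore not redundant. Below is a minimal, self-contained sketch of the restored pattern, assuming a scalar-loss network; the class name TrainOneStepSketch and its constructor details are illustrative, not this repository's exact code.

import numpy as np

import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import ParameterTuple


class TrainOneStepSketch(nn.Cell):
    """Illustrative cell: forward, backward, and update in one call."""

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepSketch, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        # get_by_list returns gradients w.r.t. the weight tuple;
        # sens_param lets the initial gradient (loss scale) be fed in.
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = Tensor(np.full((1,), sens, dtype=np.float32))

    def construct(self, *inputs):
        loss = self.network(*inputs)
        grads = self.grad(self.network, self.weights)(*inputs, self.sens)
        # F.depend returns loss while making it depend on the optimizer's
        # side effect, so the parameter update cannot be pruned away;
        # `self.optimizer(grads); return loss` leaves that ordering
        # unspecified under graph compilation.
        return F.depend(loss, self.optimizer(grads))

The distributed branch shown in the diff additionally passes the gradients through self.grad_reducer (a DistributedGradReducer) before the update when reduce_flag is set; the sketch omits that step for brevity.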