!22080 revert remove_redundant_depend in some network scripts
Merge pull request !22080 from zhouneng/code_docs_fix_issue_I46EVQ
commit f3d6534077
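Background: in MindSpore graph mode, an operator call whose result is never consumed (such as a bare `self.optimizer(grads)` just before a `return`) can be pruned or reordered during graph compilation, so the parameter update may silently not execute. Wrapping the return value with `F.depend` makes the returned tensor depend on the optimizer update and keeps it in the executed graph, which is what this commit restores. The snippet below is a minimal illustrative sketch of that pattern under assumed names (`TrainStepSketch` and its members are placeholders), not the repository's exact code.

# Minimal sketch (assumed names, not the repository's exact code) of the
# training-step pattern this commit restores: the optimizer update is tied
# to the returned loss through F.depend so graph compilation cannot drop it.
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F


class TrainStepSketch(nn.Cell):
    """Hypothetical wrapper; mirrors the structure shown in the diff below."""

    def __init__(self, network, optimizer):
        super(TrainStepSketch, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        grads = self.grad(self.network, self.weights)(*inputs)
        # The reverted-away form (a bare `self.optimizer(grads)` followed by
        # `return loss`) leaves the update outside the data flow of the
        # returned value; F.depend re-attaches it.
        return F.depend(loss, self.optimizer(grads))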
@@ -310,8 +310,8 @@ class TrainingWrapper(nn.Cell):
         else:
             cond = self.less_equal(self.base, flag_sum)
 
-        self.optimizer(grads)
-        return (loss, cond, sens)
+        ret = (loss, cond, sens)
+        return F.depend(ret, self.optimizer(grads))
 
 
 class CenterFaceWithNms(nn.Cell):
@@ -20,6 +20,7 @@ from mindspore.nn.layer.activation import get_activation
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.context import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -260,8 +261,7 @@ class TrainStepWrap(nn.Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        self.optimizer(grads)
-        return loss
+        return F.depend(loss, self.optimizer(grads))
 
 
 class PredictWithSigmoid(nn.Cell):