revert remove_redundant_depend

This commit is contained in:
huangbingjian 2021-08-27 14:24:46 +08:00
parent 111d1a9a61
commit fbe4c5278a
2 changed files with 6 additions and 4 deletions

View File

@ -19,6 +19,7 @@ import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@ -146,5 +147,5 @@ class TrainOneStepCell(nn.Cell):
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss
return F.depend(loss, self.optimizer(grads))

View File

@ -19,6 +19,7 @@ import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@ -145,5 +146,5 @@ class TrainOneStepCell(nn.Cell):
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss
return F.depend(loss, self.optimizer(grads))