fix centernet loss error.

This commit is contained in:
CaoJian 2021-03-20 21:42:25 +08:00
parent 3539952b66
commit 21a2c7bcbe
2 changed files with 3 additions and 6 deletions

View File

@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell):
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
else:
succ = self.optimizer(grads)
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)

View File

@ -137,7 +137,7 @@ class GatherFeature(nn.Cell):
self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims()
else:
self.gather_nd = ops.GatherND()
self.gather_nd = ops.GatherNd()
def construct(self, feat, ind):
"""gather by specified index"""