forked from mindspore-Ecosystem/mindspore
fix centernet loss error.
This commit is contained in:
parent
3539952b66
commit
21a2c7bcbe
|
@ -308,10 +308,7 @@ class CenterNetWithLossScaleCell(nn.Cell):
|
|||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return ops.depend(ret, succ)
|
||||
|
|
|
@ -137,7 +137,7 @@ class GatherFeature(nn.Cell):
|
|||
self.gather_nd = ops.GatherD()
|
||||
self.expand_dims = ops.ExpandDims()
|
||||
else:
|
||||
self.gather_nd = ops.GatherND()
|
||||
self.gather_nd = ops.GatherNd()
|
||||
|
||||
def construct(self, feat, ind):
|
||||
"""gather by specified index"""
|
||||
|
|
Loading…
Reference in New Issue