forked from mindspore-Ecosystem/mindspore
!17694 fix unet multiclass bug
From: @zhao_ting_v Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
commit
b5cfded805
|
@ -86,9 +86,10 @@ class CrossEntropyWithLogits(MyLoss):
|
|||
logits = self.transpose_fn(logits, (0, 2, 3, 1))
|
||||
logits = self.cast(logits, mindspore.float32)
|
||||
label = self.transpose_fn(label, (0, 2, 3, 1))
|
||||
_, _, _, c = F.Shape()(label)
|
||||
|
||||
loss = self.reduce_mean(
|
||||
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
|
||||
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c))))
|
||||
return self.get_loss(loss)
|
||||
|
||||
class MultiCrossEntropyWithLogits(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue