forked from mindspore-Ecosystem/mindspore
!17694 fix unet multiclass bug
From: @zhao_ting_v Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
commit
b5cfded805
|
@ -86,9 +86,10 @@ class CrossEntropyWithLogits(MyLoss):
|
||||||
logits = self.transpose_fn(logits, (0, 2, 3, 1))
|
logits = self.transpose_fn(logits, (0, 2, 3, 1))
|
||||||
logits = self.cast(logits, mindspore.float32)
|
logits = self.cast(logits, mindspore.float32)
|
||||||
label = self.transpose_fn(label, (0, 2, 3, 1))
|
label = self.transpose_fn(label, (0, 2, 3, 1))
|
||||||
|
_, _, _, c = F.Shape()(label)
|
||||||
|
|
||||||
loss = self.reduce_mean(
|
loss = self.reduce_mean(
|
||||||
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
|
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c))))
|
||||||
return self.get_loss(loss)
|
return self.get_loss(loss)
|
||||||
|
|
||||||
class MultiCrossEntropyWithLogits(nn.Cell):
|
class MultiCrossEntropyWithLogits(nn.Cell):
|
||||||
|
|
Loading…
Reference in New Issue