!17694 fix unet multiclass bug

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
mindspore-ci-bot 2021-06-07 09:21:38 +08:00 committed by Gitee
commit b5cfded805
1 changed files with 2 additions and 1 deletions

View File

@ -86,9 +86,10 @@ class CrossEntropyWithLogits(MyLoss):
logits = self.transpose_fn(logits, (0, 2, 3, 1))
logits = self.cast(logits, mindspore.float32)
label = self.transpose_fn(label, (0, 2, 3, 1))
_, _, _, c = F.Shape()(label)
loss = self.reduce_mean(
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c))))
return self.get_loss(loss)
class MultiCrossEntropyWithLogits(nn.Cell):