From 731ee5463a550a088aba1cdfbd142502542fb8e4 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Thu, 3 Jun 2021 09:51:21 +0800 Subject: [PATCH] fix unet multiclass bug --- model_zoo/official/cv/unet/src/loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/cv/unet/src/loss.py b/model_zoo/official/cv/unet/src/loss.py index 81c25288632..ac0d63f7fb9 100644 --- a/model_zoo/official/cv/unet/src/loss.py +++ b/model_zoo/official/cv/unet/src/loss.py @@ -32,9 +32,10 @@ class CrossEntropyWithLogits(_Loss): logits = self.transpose_fn(logits, (0, 2, 3, 1)) logits = self.cast(logits, mindspore.float32) label = self.transpose_fn(label, (0, 2, 3, 1)) + _, _, _, c = F.Shape()(label) loss = self.reduce_mean( - self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2)))) + self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c)))) return self.get_loss(loss) class MultiCrossEntropyWithLogits(nn.Cell):