diff --git a/model_zoo/official/cv/unet/src/loss.py b/model_zoo/official/cv/unet/src/loss.py index fd0b6b8f9ff..4f60319efd1 100644 --- a/model_zoo/official/cv/unet/src/loss.py +++ b/model_zoo/official/cv/unet/src/loss.py @@ -86,9 +86,10 @@ class CrossEntropyWithLogits(MyLoss): logits = self.transpose_fn(logits, (0, 2, 3, 1)) logits = self.cast(logits, mindspore.float32) label = self.transpose_fn(label, (0, 2, 3, 1)) + _, _, _, c = F.Shape()(label) loss = self.reduce_mean( - self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2)))) + self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c)))) return self.get_loss(loss) class MultiCrossEntropyWithLogits(nn.Cell):