fix unet multiclass bug
This commit is contained in:
parent
9179505446
commit
731ee5463a
|
@ -32,9 +32,10 @@ class CrossEntropyWithLogits(_Loss):
|
|||
logits = self.transpose_fn(logits, (0, 2, 3, 1))
|
||||
logits = self.cast(logits, mindspore.float32)
|
||||
label = self.transpose_fn(label, (0, 2, 3, 1))
|
||||
_, _, _, c = F.Shape()(label)
|
||||
|
||||
loss = self.reduce_mean(
|
||||
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
|
||||
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, c)), self.reshape_fn(label, (-1, c))))
|
||||
return self.get_loss(loss)
|
||||
|
||||
class MultiCrossEntropyWithLogits(nn.Cell):
|
||||
|
|
Loading…
Reference in New Issue