forked from mindspore-Ecosystem/mindspore
move argmax from host to device
This commit is contained in:
parent
34864fbc56
commit
7ae7505caa
|
@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell):
|
|||
self.concat = P.Concat(axis=2)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.argmax = P.Argmax(axis=1)
|
||||
self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
|
||||
int(feature_shape[3])),
|
||||
align_corners=True)
|
||||
|
@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell):
|
|||
logits_i = self.expand_dims(logits_i, 2)
|
||||
logits = self.concat((logits, logits_i))
|
||||
logits = self.reduce_mean(logits, 2)
|
||||
if not self.training:
|
||||
logits = self.argmax(logits)
|
||||
return logits
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ class OhemLoss(nn.Cell):
|
|||
self.loss_weight = 1.0
|
||||
|
||||
def construct(self, logits, labels):
|
||||
if not self.training:
|
||||
return 0
|
||||
logits = self.transpose(logits, (0, 2, 3, 1))
|
||||
logits = self.reshape(logits, (-1, self.num))
|
||||
labels = F.cast(labels, mstype.int32)
|
||||
|
|
|
@ -50,10 +50,7 @@ class MiouPrecision(Metric):
|
|||
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
||||
predict_in = self._convert_data(inputs[0])
|
||||
label_in = self._convert_data(inputs[1])
|
||||
if predict_in.shape[1] != self._num_class:
|
||||
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
|
||||
'classes'.format(self._num_class, predict_in.shape[1]))
|
||||
pred = np.argmax(predict_in, axis=1)
|
||||
pred = predict_in
|
||||
label = label_in
|
||||
if len(label.flatten()) != len(pred.flatten()):
|
||||
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
|
||||
|
|
Loading…
Reference in New Issue