forked from mindspore-Ecosystem/mindspore
move argmax from host to device
This commit is contained in:
parent
34864fbc56
commit
7ae7505caa
|
@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell):
|
||||||
self.concat = P.Concat(axis=2)
|
self.concat = P.Concat(axis=2)
|
||||||
self.expand_dims = P.ExpandDims()
|
self.expand_dims = P.ExpandDims()
|
||||||
self.reduce_mean = P.ReduceMean()
|
self.reduce_mean = P.ReduceMean()
|
||||||
|
self.argmax = P.Argmax(axis=1)
|
||||||
self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
|
self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
|
||||||
int(feature_shape[3])),
|
int(feature_shape[3])),
|
||||||
align_corners=True)
|
align_corners=True)
|
||||||
|
@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell):
|
||||||
logits_i = self.expand_dims(logits_i, 2)
|
logits_i = self.expand_dims(logits_i, 2)
|
||||||
logits = self.concat((logits, logits_i))
|
logits = self.concat((logits, logits_i))
|
||||||
logits = self.reduce_mean(logits, 2)
|
logits = self.reduce_mean(logits, 2)
|
||||||
|
if not self.training:
|
||||||
|
logits = self.argmax(logits)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,8 @@ class OhemLoss(nn.Cell):
|
||||||
self.loss_weight = 1.0
|
self.loss_weight = 1.0
|
||||||
|
|
||||||
def construct(self, logits, labels):
|
def construct(self, logits, labels):
|
||||||
|
if not self.training:
|
||||||
|
return 0
|
||||||
logits = self.transpose(logits, (0, 2, 3, 1))
|
logits = self.transpose(logits, (0, 2, 3, 1))
|
||||||
logits = self.reshape(logits, (-1, self.num))
|
logits = self.reshape(logits, (-1, self.num))
|
||||||
labels = F.cast(labels, mstype.int32)
|
labels = F.cast(labels, mstype.int32)
|
||||||
|
|
|
@ -50,10 +50,7 @@ class MiouPrecision(Metric):
|
||||||
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
||||||
predict_in = self._convert_data(inputs[0])
|
predict_in = self._convert_data(inputs[0])
|
||||||
label_in = self._convert_data(inputs[1])
|
label_in = self._convert_data(inputs[1])
|
||||||
if predict_in.shape[1] != self._num_class:
|
pred = predict_in
|
||||||
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
|
|
||||||
'classes'.format(self._num_class, predict_in.shape[1]))
|
|
||||||
pred = np.argmax(predict_in, axis=1)
|
|
||||||
label = label_in
|
label = label_in
|
||||||
if len(label.flatten()) != len(pred.flatten()):
|
if len(label.flatten()) != len(pred.flatten()):
|
||||||
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
|
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))
|
||||||
|
|
Loading…
Reference in New Issue