move argmax from host to device

This commit is contained in:
zhouyaqiang 2020-08-10 09:39:52 +08:00
parent 34864fbc56
commit 7ae7505caa
3 changed files with 6 additions and 4 deletions

View File

@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell):
self.concat = P.Concat(axis=2) self.concat = P.Concat(axis=2)
self.expand_dims = P.ExpandDims() self.expand_dims = P.ExpandDims()
self.reduce_mean = P.ReduceMean() self.reduce_mean = P.ReduceMean()
self.argmax = P.Argmax(axis=1)
self.sample_common = P.ResizeBilinear((int(feature_shape[2]), self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
int(feature_shape[3])), int(feature_shape[3])),
align_corners=True) align_corners=True)
@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell):
logits_i = self.expand_dims(logits_i, 2) logits_i = self.expand_dims(logits_i, 2)
logits = self.concat((logits, logits_i)) logits = self.concat((logits, logits_i))
logits = self.reduce_mean(logits, 2) logits = self.reduce_mean(logits, 2)
if not self.training:
logits = self.argmax(logits)
return logits return logits

View File

@ -42,6 +42,8 @@ class OhemLoss(nn.Cell):
self.loss_weight = 1.0 self.loss_weight = 1.0
def construct(self, logits, labels): def construct(self, logits, labels):
if not self.training:
return 0
logits = self.transpose(logits, (0, 2, 3, 1)) logits = self.transpose(logits, (0, 2, 3, 1))
logits = self.reshape(logits, (-1, self.num)) logits = self.reshape(logits, (-1, self.num))
labels = F.cast(labels, mstype.int32) labels = F.cast(labels, mstype.int32)

View File

@ -50,10 +50,7 @@ class MiouPrecision(Metric):
raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
predict_in = self._convert_data(inputs[0]) predict_in = self._convert_data(inputs[0])
label_in = self._convert_data(inputs[1]) label_in = self._convert_data(inputs[1])
if predict_in.shape[1] != self._num_class: pred = predict_in
raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} '
'classes'.format(self._num_class, predict_in.shape[1]))
pred = np.argmax(predict_in, axis=1)
label = label_in label = label_in
if len(label.flatten()) != len(pred.flatten()): if len(label.flatten()) != len(pred.flatten()):
print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten()))) print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))