From 65e5a850d231db152eacda993d0e8eedab95cb63 Mon Sep 17 00:00:00 2001 From: wangshuide2020 Date: Thu, 1 Jul 2021 14:54:15 +0800 Subject: [PATCH] fix the tensor validation of loss operators. --- mindspore/nn/loss/loss.py | 60 +++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 424e7646f2a..23c24e53cbb 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -107,9 +107,9 @@ class _Loss(Loss): @constexpr -def _check_input_type(param_name, input_data, allow_dtype, cls_name): - if input_data is not None and not isinstance(input_data, allow_dtype): - raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{allow_dtype}', " +def _check_is_tensor(param_name, input_data, cls_name): + if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type): + raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " f"but got '{F.typeof(input_data)}'") class L1Loss(Loss): @@ -176,8 +176,8 @@ class L1Loss(Loss): self.abs = P.Abs() def construct(self, base, target): - _check_input_type('logits', base, Tensor, self.cls_name) - _check_input_type('labels', target, Tensor, self.cls_name) + _check_is_tensor('logits', base, self.cls_name) + _check_is_tensor('labels', target, self.cls_name) x = self.abs(base - target) return self.get_loss(x) @@ -241,8 +241,8 @@ class MSELoss(Loss): [0. 0. 1.]] """ def construct(self, base, target): - _check_input_type('logits', base, Tensor, self.cls_name) - _check_input_type('labels', target, Tensor, self.cls_name) + _check_is_tensor('logits', base, self.cls_name) + _check_is_tensor('labels', target, self.cls_name) x = F.square(base - target) return self.get_loss(x) @@ -363,8 +363,8 @@ class MAELoss(Loss): self.abs = P.Abs() def construct(self, logits, label): - _check_input_type('logits', logits, Tensor, self.cls_name) - _check_input_type('labels', label, Tensor, self.cls_name) + _check_is_tensor('logits', logits, self.cls_name) + _check_is_tensor('labels', label, self.cls_name) x = self.abs(logits - label) return self.get_loss(x) @@ -429,8 +429,8 @@ class SmoothL1Loss(Loss): self.smooth_l1_loss = P.SmoothL1Loss(self.beta) def construct(self, base, target): - _check_input_type('logits', base, Tensor, self.cls_name) - _check_input_type('labels', target, Tensor, self.cls_name) + _check_is_tensor('logits', base, self.cls_name) + _check_is_tensor('labels', target, self.cls_name) return self.smooth_l1_loss(base, target) @@ -512,8 +512,8 @@ class SoftmaxCrossEntropyWithLogits(Loss): self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() def construct(self, logits, labels): - _check_input_type('logits', logits, Tensor, self.cls_name) - _check_input_type('labels', labels, Tensor, self.cls_name) + _check_is_tensor('logits', logits, self.cls_name) + _check_is_tensor('labels', labels, self.cls_name) if self.sparse: if self.reduction == 'mean': x = self.sparse_softmax_cross_entropy(logits, labels) @@ -572,8 +572,8 @@ class DiceLoss(Loss): self.reshape = P.Reshape() def construct(self, logits, label): - _check_input_type('logits', logits, Tensor, self.cls_name) - _check_input_type('labels', label, Tensor, self.cls_name) + _check_is_tensor('logits', logits, self.cls_name) + _check_is_tensor('labels', label, self.cls_name) _check_shape(logits.shape, label.shape) intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1))) unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \ @@ -664,8 +664,8 @@ class MultiClassDiceLoss(Loss): self.reshape = P.Reshape() def construct(self, logits, label): - _check_input_type('logits', logits, Tensor, self.cls_name) - _check_input_type('labels', label, Tensor, self.cls_name) + _check_is_tensor('logits', logits, self.cls_name) + _check_is_tensor('labels', label, self.cls_name) _check_shape(logits.shape, label.shape) _check_ndim_multi(logits.ndim, label.ndim) total_loss = 0 @@ -787,10 +787,10 @@ class SampledSoftmaxLoss(Loss): self.dtype = P.DType() def construct(self, weights, biases, labels, inputs): - _check_input_type('weights', weights, Tensor, self.cls_name) - _check_input_type('biases', biases, Tensor, self.cls_name) - _check_input_type('labels', labels, Tensor, self.cls_name) - _check_input_type('inputs', inputs, Tensor, self.cls_name) + _check_is_tensor('weights', weights, self.cls_name) + _check_is_tensor('biases', biases, self.cls_name) + _check_is_tensor('labels', labels, self.cls_name) + _check_is_tensor('inputs', inputs, self.cls_name) _check_label_dtype(self.dtype(labels), self.cls_name) logits, labels = self._compute_sampled_logits( @@ -972,8 +972,8 @@ class BCELoss(Loss): self.ones = P.OnesLike() def construct(self, inputs, labels): - _check_input_type('logits', inputs, Tensor, self.cls_name) - _check_input_type('labels', labels, Tensor, self.cls_name) + _check_is_tensor('logits', inputs, self.cls_name) + _check_is_tensor('labels', labels, self.cls_name) if self.weight_one: weight = self.ones(inputs) else: @@ -1041,9 +1041,9 @@ class CosineEmbeddingLoss(Loss): self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name) def construct(self, x1, x2, y): - _check_input_type('logits_x1', x1, Tensor, self.cls_name) - _check_input_type('logits_x2', x2, Tensor, self.cls_name) - _check_input_type('labels', y, Tensor, self.cls_name) + _check_is_tensor('logits_x1', x1, self.cls_name) + _check_is_tensor('logits_x2', x2, self.cls_name) + _check_is_tensor('labels', y, self.cls_name) F.same_type_shape(x1, x2) _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name) # if target > 0, 1-cosine(x1, x2) @@ -1137,8 +1137,8 @@ class BCEWithLogitsLoss(Loss): self.ones = P.OnesLike() def construct(self, predict, target): - _check_input_type('logits', predict, Tensor, self.cls_name) - _check_input_type('labels', target, Tensor, self.cls_name) + _check_is_tensor('logits', predict, self.cls_name) + _check_is_tensor('labels', target, self.cls_name) ones_input = self.ones(predict) if self.weight is not None: weight = self.weight @@ -1250,8 +1250,8 @@ class FocalLoss(Loss): self.logsoftmax = nn.LogSoftmax(1) def construct(self, predict, target): - _check_input_type('logits', predict, Tensor, self.cls_name) - _check_input_type('labels', target, Tensor, self.cls_name) + _check_is_tensor('logits', predict, self.cls_name) + _check_is_tensor('labels', target, self.cls_name) targets = target _check_ndim(predict.ndim, targets.ndim) _check_channel_and_shape(predict.shape[1], targets.shape[1])