fix the tensor validation of loss operators.

This commit is contained in:
wangshuide2020 2021-07-01 14:54:15 +08:00
parent d173a88b93
commit 65e5a850d2
1 changed files with 30 additions and 30 deletions

View File

@ -107,9 +107,9 @@ class _Loss(Loss):
@constexpr
def _check_input_type(param_name, input_data, allow_dtype, cls_name):
if input_data is not None and not isinstance(input_data, allow_dtype):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{allow_dtype}', "
def _check_is_tensor(param_name, input_data, cls_name):
if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
f"but got '{F.typeof(input_data)}'")
class L1Loss(Loss):
@ -176,8 +176,8 @@ class L1Loss(Loss):
self.abs = P.Abs()
def construct(self, base, target):
_check_input_type('logits', base, Tensor, self.cls_name)
_check_input_type('labels', target, Tensor, self.cls_name)
_check_is_tensor('logits', base, self.cls_name)
_check_is_tensor('labels', target, self.cls_name)
x = self.abs(base - target)
return self.get_loss(x)
@ -241,8 +241,8 @@ class MSELoss(Loss):
[0. 0. 1.]]
"""
def construct(self, base, target):
_check_input_type('logits', base, Tensor, self.cls_name)
_check_input_type('labels', target, Tensor, self.cls_name)
_check_is_tensor('logits', base, self.cls_name)
_check_is_tensor('labels', target, self.cls_name)
x = F.square(base - target)
return self.get_loss(x)
@ -363,8 +363,8 @@ class MAELoss(Loss):
self.abs = P.Abs()
def construct(self, logits, label):
_check_input_type('logits', logits, Tensor, self.cls_name)
_check_input_type('labels', label, Tensor, self.cls_name)
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', label, self.cls_name)
x = self.abs(logits - label)
return self.get_loss(x)
@ -429,8 +429,8 @@ class SmoothL1Loss(Loss):
self.smooth_l1_loss = P.SmoothL1Loss(self.beta)
def construct(self, base, target):
_check_input_type('logits', base, Tensor, self.cls_name)
_check_input_type('labels', target, Tensor, self.cls_name)
_check_is_tensor('logits', base, self.cls_name)
_check_is_tensor('labels', target, self.cls_name)
return self.smooth_l1_loss(base, target)
@ -512,8 +512,8 @@ class SoftmaxCrossEntropyWithLogits(Loss):
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
def construct(self, logits, labels):
_check_input_type('logits', logits, Tensor, self.cls_name)
_check_input_type('labels', labels, Tensor, self.cls_name)
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
if self.sparse:
if self.reduction == 'mean':
x = self.sparse_softmax_cross_entropy(logits, labels)
@ -572,8 +572,8 @@ class DiceLoss(Loss):
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_input_type('logits', logits, Tensor, self.cls_name)
_check_input_type('labels', label, Tensor, self.cls_name)
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', label, self.cls_name)
_check_shape(logits.shape, label.shape)
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
@ -664,8 +664,8 @@ class MultiClassDiceLoss(Loss):
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_input_type('logits', logits, Tensor, self.cls_name)
_check_input_type('labels', label, Tensor, self.cls_name)
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', label, self.cls_name)
_check_shape(logits.shape, label.shape)
_check_ndim_multi(logits.ndim, label.ndim)
total_loss = 0
@ -787,10 +787,10 @@ class SampledSoftmaxLoss(Loss):
self.dtype = P.DType()
def construct(self, weights, biases, labels, inputs):
_check_input_type('weights', weights, Tensor, self.cls_name)
_check_input_type('biases', biases, Tensor, self.cls_name)
_check_input_type('labels', labels, Tensor, self.cls_name)
_check_input_type('inputs', inputs, Tensor, self.cls_name)
_check_is_tensor('weights', weights, self.cls_name)
_check_is_tensor('biases', biases, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
_check_is_tensor('inputs', inputs, self.cls_name)
_check_label_dtype(self.dtype(labels), self.cls_name)
logits, labels = self._compute_sampled_logits(
@ -972,8 +972,8 @@ class BCELoss(Loss):
self.ones = P.OnesLike()
def construct(self, inputs, labels):
_check_input_type('logits', inputs, Tensor, self.cls_name)
_check_input_type('labels', labels, Tensor, self.cls_name)
_check_is_tensor('logits', inputs, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
if self.weight_one:
weight = self.ones(inputs)
else:
@ -1041,9 +1041,9 @@ class CosineEmbeddingLoss(Loss):
self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)
def construct(self, x1, x2, y):
_check_input_type('logits_x1', x1, Tensor, self.cls_name)
_check_input_type('logits_x2', x2, Tensor, self.cls_name)
_check_input_type('labels', y, Tensor, self.cls_name)
_check_is_tensor('logits_x1', x1, self.cls_name)
_check_is_tensor('logits_x2', x2, self.cls_name)
_check_is_tensor('labels', y, self.cls_name)
F.same_type_shape(x1, x2)
_check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
# if target > 0, 1-cosine(x1, x2)
@ -1137,8 +1137,8 @@ class BCEWithLogitsLoss(Loss):
self.ones = P.OnesLike()
def construct(self, predict, target):
_check_input_type('logits', predict, Tensor, self.cls_name)
_check_input_type('labels', target, Tensor, self.cls_name)
_check_is_tensor('logits', predict, self.cls_name)
_check_is_tensor('labels', target, self.cls_name)
ones_input = self.ones(predict)
if self.weight is not None:
weight = self.weight
@ -1250,8 +1250,8 @@ class FocalLoss(Loss):
self.logsoftmax = nn.LogSoftmax(1)
def construct(self, predict, target):
_check_input_type('logits', predict, Tensor, self.cls_name)
_check_input_type('labels', target, Tensor, self.cls_name)
_check_is_tensor('logits', predict, self.cls_name)
_check_is_tensor('labels', target, self.cls_name)
targets = target
_check_ndim(predict.ndim, targets.ndim)
_check_channel_and_shape(predict.shape[1], targets.shape[1])