Fix the tensor validation of loss operators.
parent d173a88b93
commit 65e5a850d2
@@ -107,9 +107,9 @@ class _Loss(Loss):
 @constexpr
-def _check_input_type(param_name, input_data, allow_dtype, cls_name):
-    if input_data is not None and not isinstance(input_data, allow_dtype):
-        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{allow_dtype}', "
+def _check_is_tensor(param_name, input_data, cls_name):
+    if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type):
+        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
+                        f"but got '{F.typeof(input_data)}'")


 class L1Loss(Loss):
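This hunk is the heart of the commit: `_check_input_type` tested `isinstance(input_data, Tensor)` directly, while the new `_check_is_tensor` compares `F.typeof(input_data)` against `mstype.tensor_type`, so a bad argument is reported with the MindSpore type actually observed. A minimal sketch of the new check, assuming `F.typeof` also returns a `mstype.tensor_type` instance for tensors in PyNative mode (the real helper runs under `@constexpr` during graph compilation):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

def check_is_tensor(param_name, input_data, cls_name):
    # Mirror of the helper above: skip when the compile-time value is
    # unknown (None), otherwise require a MindSpore tensor type.
    if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type):
        raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                        f"but got '{F.typeof(input_data)}'")

check_is_tensor('logits', Tensor(np.ones(3, np.float32)), 'L1Loss')  # passes
check_is_tensor('logits', [1.0, 2.0, 3.0], 'L1Loss')                 # raises TypeError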
@@ -176,8 +176,8 @@ class L1Loss(Loss):
         self.abs = P.Abs()

     def construct(self, base, target):
-        _check_input_type('logits', base, Tensor, self.cls_name)
-        _check_input_type('labels', target, Tensor, self.cls_name)
+        _check_is_tensor('logits', base, self.cls_name)
+        _check_is_tensor('labels', target, self.cls_name)
         x = self.abs(base - target)
         return self.get_loss(x)

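For context, this is how the tightened check surfaces through the public API; a sketch with the default reduction='mean' (values are illustrative, not from this commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.L1Loss()                                      # reduction='mean' by default
logits = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
labels = Tensor(np.array([1.0, 2.0, 2.0], np.float32))
print(loss(logits, labels))                             # mean absolute error, ~0.3333

loss(logits, [1.0, 2.0, 2.0])  # now rejected up front: TypeError from _check_is_tensor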
@@ -241,8 +241,8 @@ class MSELoss(Loss):
         [0. 0. 1.]]
     """
     def construct(self, base, target):
-        _check_input_type('logits', base, Tensor, self.cls_name)
-        _check_input_type('labels', target, Tensor, self.cls_name)
+        _check_is_tensor('logits', base, self.cls_name)
+        _check_is_tensor('labels', target, self.cls_name)
         x = F.square(base - target)
         return self.get_loss(x)

@@ -363,8 +363,8 @@ class MAELoss(Loss):
         self.abs = P.Abs()

     def construct(self, logits, label):
-        _check_input_type('logits', logits, Tensor, self.cls_name)
-        _check_input_type('labels', label, Tensor, self.cls_name)
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', label, self.cls_name)
         x = self.abs(logits - label)
         return self.get_loss(x)

@@ -429,8 +429,8 @@ class SmoothL1Loss(Loss):
         self.smooth_l1_loss = P.SmoothL1Loss(self.beta)

     def construct(self, base, target):
-        _check_input_type('logits', base, Tensor, self.cls_name)
-        _check_input_type('labels', target, Tensor, self.cls_name)
+        _check_is_tensor('logits', base, self.cls_name)
+        _check_is_tensor('labels', target, self.cls_name)
         return self.smooth_l1_loss(base, target)


@@ -512,8 +512,8 @@ class SoftmaxCrossEntropyWithLogits(Loss):
         self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()

     def construct(self, logits, labels):
-        _check_input_type('logits', logits, Tensor, self.cls_name)
-        _check_input_type('labels', labels, Tensor, self.cls_name)
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
         if self.sparse:
             if self.reduction == 'mean':
                 x = self.sparse_softmax_cross_entropy(logits, labels)
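The context lines show that the fused sparse kernel (P.SparseSoftmaxCrossEntropyWithLogits) is only taken when sparse=True and reduction='mean'. A usage sketch exercising that branch (shapes per the operator's documented contract; values illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
logits = Tensor(np.random.randn(3, 5).astype(np.float32))  # (batch, num_classes)
labels = Tensor(np.array([1, 0, 4], np.int32))             # class indices, not one-hot
print(loss(logits, labels))                                # scalar mean cross-entropy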
@@ -572,8 +572,8 @@ class DiceLoss(Loss):
         self.reshape = P.Reshape()

     def construct(self, logits, label):
-        _check_input_type('logits', logits, Tensor, self.cls_name)
-        _check_input_type('labels', label, Tensor, self.cls_name)
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', label, self.cls_name)
         _check_shape(logits.shape, label.shape)
         intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
         unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
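The hunk cuts off mid-expression; for orientation, `intersection` and `unionset` implement the dice coefficient 2*sum(p*t) / (sum(p^2) + sum(t^2)), and the loss is one minus that. A NumPy sketch of the same arithmetic (the placement of the smoothing term is an assumption; `DiceLoss` takes `smooth` as a constructor argument):

import numpy as np

def dice_loss(logits, label, smooth=1e-5):
    # Flatten both inputs, as logits.view(-1) / label.view(-1) do above.
    p, t = logits.reshape(-1), label.reshape(-1)
    intersection = np.sum(p * t)
    unionset = np.sum(p * p) + np.sum(t * t)
    # Smoothing keeps the ratio finite for all-zero inputs (placement assumed).
    return 1.0 - (2.0 * intersection + smooth) / (unionset + smooth)

print(dice_loss(np.array([0.2, 0.7, 0.9]), np.array([0.0, 1.0, 1.0])))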
@@ -664,8 +664,8 @@ class MultiClassDiceLoss(Loss):
         self.reshape = P.Reshape()

     def construct(self, logits, label):
-        _check_input_type('logits', logits, Tensor, self.cls_name)
-        _check_input_type('labels', label, Tensor, self.cls_name)
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', label, self.cls_name)
         _check_shape(logits.shape, label.shape)
         _check_ndim_multi(logits.ndim, label.ndim)
         total_loss = 0
@@ -787,10 +787,10 @@ class SampledSoftmaxLoss(Loss):
         self.dtype = P.DType()

     def construct(self, weights, biases, labels, inputs):
-        _check_input_type('weights', weights, Tensor, self.cls_name)
-        _check_input_type('biases', biases, Tensor, self.cls_name)
-        _check_input_type('labels', labels, Tensor, self.cls_name)
-        _check_input_type('inputs', inputs, Tensor, self.cls_name)
+        _check_is_tensor('weights', weights, self.cls_name)
+        _check_is_tensor('biases', biases, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
+        _check_is_tensor('inputs', inputs, self.cls_name)
         _check_label_dtype(self.dtype(labels), self.cls_name)

         logits, labels = self._compute_sampled_logits(
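This is the one operator where four inputs are validated, and `_check_label_dtype` additionally pins `labels` to an integer type. A usage sketch following the operator's documented contract (shapes and values illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, num_true=1)
weights = Tensor(np.random.rand(7, 10), mindspore.float32)  # (num_classes, dim)
biases = Tensor(np.random.rand(7), mindspore.float32)       # (num_classes,)
labels = Tensor(np.array([0, 1, 2]), mindspore.int32)       # integer tensor, per _check_label_dtype
inputs = Tensor(np.random.rand(3, 10), mindspore.float32)   # (batch, dim)
print(loss(weights, biases, labels, inputs))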
@@ -972,8 +972,8 @@ class BCELoss(Loss):
         self.ones = P.OnesLike()

     def construct(self, inputs, labels):
-        _check_input_type('logits', inputs, Tensor, self.cls_name)
-        _check_input_type('labels', labels, Tensor, self.cls_name)
+        _check_is_tensor('logits', inputs, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
         if self.weight_one:
             weight = self.ones(inputs)
         else:
@@ -1041,9 +1041,9 @@ class CosineEmbeddingLoss(Loss):
         self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)

     def construct(self, x1, x2, y):
-        _check_input_type('logits_x1', x1, Tensor, self.cls_name)
-        _check_input_type('logits_x2', x2, Tensor, self.cls_name)
-        _check_input_type('labels', y, Tensor, self.cls_name)
+        _check_is_tensor('logits_x1', x1, self.cls_name)
+        _check_is_tensor('logits_x2', x2, self.cls_name)
+        _check_is_tensor('labels', y, self.cls_name)
         F.same_type_shape(x1, x2)
         _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
         # if target > 0, 1-cosine(x1, x2)
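Besides the tensor checks, this construct enforces that x1 and x2 agree in type and shape (F.same_type_shape) and that y has x1's shape with axis 1 reduced. A sketch along the lines of the operator's documented example (values illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

logits_x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]), mindspore.float32)
logits_x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
labels = Tensor(np.array([1, -1]), mindspore.int32)  # +1: pull pair together, -1: push apart
cosine_embedding_loss = nn.CosineEmbeddingLoss()     # margin defaults to 0.0
print(cosine_embedding_loss(logits_x1, logits_x2, labels))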
@@ -1137,8 +1137,8 @@ class BCEWithLogitsLoss(Loss):
         self.ones = P.OnesLike()

     def construct(self, predict, target):
-        _check_input_type('logits', predict, Tensor, self.cls_name)
-        _check_input_type('labels', target, Tensor, self.cls_name)
+        _check_is_tensor('logits', predict, self.cls_name)
+        _check_is_tensor('labels', target, self.cls_name)
         ones_input = self.ones(predict)
         if self.weight is not None:
             weight = self.weight
@@ -1250,8 +1250,8 @@ class FocalLoss(Loss):
         self.logsoftmax = nn.LogSoftmax(1)

     def construct(self, predict, target):
-        _check_input_type('logits', predict, Tensor, self.cls_name)
-        _check_input_type('labels', target, Tensor, self.cls_name)
+        _check_is_tensor('logits', predict, self.cls_name)
+        _check_is_tensor('labels', target, self.cls_name)
         targets = target
         _check_ndim(predict.ndim, targets.ndim)
         _check_channel_and_shape(predict.shape[1], targets.shape[1])
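FocalLoss layers two structural checks on top of the tensor validation: `_check_ndim` requires predict and target to have the same rank, and `_check_channel_and_shape` requires target's channel dimension to match predict's or be 1 (class-index form). A usage sketch under those constraints (values illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

focal_loss = nn.FocalLoss(gamma=2.0)  # reduction='mean' by default
predict = Tensor(np.array([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]]), mindspore.float32)  # (N, C)
target = Tensor(np.array([[1], [1], [0]]), mindspore.int32)  # (N, 1): same rank, channel 1
print(focal_loss(predict, target))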