forked from mindspore-Ecosystem/mindspore
add type validation for loss operators.
This commit is contained in:
parent
d551162f5b
commit
ac0bc6a38c
|
@ -79,6 +79,11 @@ class _Loss(Cell):
|
|||
def construct(self, base, target):
|
||||
raise NotImplementedError
|
||||
|
||||
@constexpr
|
||||
def _check_input_type(param_name, input_data, allow_dtype, cls_name):
|
||||
if input_data is not None and not isinstance(input_data, allow_dtype):
|
||||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{allow_dtype}', "
|
||||
f"but got '{F.typeof(input_data)}'")
|
||||
|
||||
class L1Loss(_Loss):
|
||||
r"""
|
||||
|
@ -99,8 +104,8 @@ class L1Loss(_Loss):
|
|||
Default: "mean".
|
||||
|
||||
Inputs:
|
||||
- **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **target_data** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, loss float tensor.
|
||||
|
@ -113,9 +118,9 @@ class L1Loss(_Loss):
|
|||
|
||||
Examples:
|
||||
>>> loss = nn.L1Loss()
|
||||
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(input_data, target_data)
|
||||
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.33333334
|
||||
"""
|
||||
|
@ -124,6 +129,8 @@ class L1Loss(_Loss):
|
|||
self.abs = P.Abs()
|
||||
|
||||
def construct(self, base, target):
|
||||
_check_input_type('logits', base, Tensor, self.cls_name)
|
||||
_check_input_type('labels', target, Tensor, self.cls_name)
|
||||
x = self.abs(base - target)
|
||||
return self.get_loss(x)
|
||||
|
||||
|
@ -147,8 +154,8 @@ class MSELoss(_Loss):
|
|||
Default: "mean".
|
||||
|
||||
Inputs:
|
||||
- **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **target_data** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, weighted loss float tensor.
|
||||
|
@ -161,13 +168,15 @@ class MSELoss(_Loss):
|
|||
|
||||
Examples:
|
||||
>>> loss = nn.MSELoss()
|
||||
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(input_data, target_data)
|
||||
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.33333334
|
||||
"""
|
||||
def construct(self, base, target):
|
||||
_check_input_type('logits', base, Tensor, self.cls_name)
|
||||
_check_input_type('labels', target, Tensor, self.cls_name)
|
||||
x = F.square(base - target)
|
||||
return self.get_loss(x)
|
||||
|
||||
|
@ -187,7 +196,7 @@ class RMSELoss(_Loss):
|
|||
|
||||
Inputs:
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_M)`.
|
||||
- **label** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_N)`.
|
||||
- **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_N)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, weighted loss float tensor.
|
||||
|
@ -197,9 +206,9 @@ class RMSELoss(_Loss):
|
|||
|
||||
Examples:
|
||||
>>> loss = nn.RMSELoss()
|
||||
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(input_data, target_data)
|
||||
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.57735026
|
||||
"""
|
||||
|
@ -231,7 +240,7 @@ class MAELoss(_Loss):
|
|||
|
||||
Inputs:
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_M)`.
|
||||
- **label** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_N)`.
|
||||
- **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_N)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, weighted loss float tensor.
|
||||
|
@ -244,9 +253,9 @@ class MAELoss(_Loss):
|
|||
|
||||
Examples:
|
||||
>>> loss = nn.MAELoss()
|
||||
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(input_data, target_data)
|
||||
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.33333334
|
||||
"""
|
||||
|
@ -255,6 +264,8 @@ class MAELoss(_Loss):
|
|||
self.abs = P.Abs()
|
||||
|
||||
def construct(self, logits, label):
|
||||
_check_input_type('logits', logits, Tensor, self.cls_name)
|
||||
_check_input_type('labels', label, Tensor, self.cls_name)
|
||||
x = self.abs(logits - label)
|
||||
return self.get_loss(x)
|
||||
|
||||
|
@ -287,26 +298,26 @@ class SmoothL1Loss(_Loss):
|
|||
quadratic to linear. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
- **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. Data type must be float16 or float32.
|
||||
- **target_data** (Tensor) - Ground truth data, with the same type and shape as `input_data`.
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. Data type must be float16 or float32.
|
||||
- **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
|
||||
|
||||
Outputs:
|
||||
Tensor, loss float tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `beta` is not a float.
|
||||
TypeError: If dtype of `input_data` or `target_data` is neither float16 not float32.
|
||||
TypeError: If dtype of `logits` or `labels` is neither float16 not float32.
|
||||
ValueError: If `beta` is less than or equal to 0.
|
||||
ValueError: If shape of `input_data` is not the same as `target_data`.
|
||||
ValueError: If shape of `logits` is not the same as `labels`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> loss = nn.SmoothL1Loss()
|
||||
>>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(input_data, target_data)
|
||||
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
[0. 0. 0.5]
|
||||
"""
|
||||
|
@ -316,6 +327,8 @@ class SmoothL1Loss(_Loss):
|
|||
self.smooth_l1_loss = P.SmoothL1Loss(self.beta)
|
||||
|
||||
def construct(self, base, target):
|
||||
_check_input_type('logits', base, Tensor, self.cls_name)
|
||||
_check_input_type('labels', target, Tensor, self.cls_name)
|
||||
return self.smooth_l1_loss(base, target)
|
||||
|
||||
|
||||
|
@ -388,6 +401,8 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
|
||||
|
||||
def construct(self, logits, labels):
|
||||
_check_input_type('logits', logits, Tensor, self.cls_name)
|
||||
_check_input_type('labels', labels, Tensor, self.cls_name)
|
||||
if self.sparse:
|
||||
if self.reduction == 'mean':
|
||||
x = self.sparse_softmax_cross_entropy(logits, labels)
|
||||
|
@ -416,24 +431,24 @@ class DiceLoss(_Loss):
|
|||
Default: 1e-5.
|
||||
|
||||
Inputs:
|
||||
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||
- **logits** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||
- **labels** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||
|
||||
Outputs:
|
||||
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dimensions are different.
|
||||
TypeError: If the type of inputs are not Tensor.
|
||||
TypeError: If the type of `logits` or `labels` are not Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> loss = nn.DiceLoss(smooth=1e-5)
|
||||
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
>>> output = loss(y_pred, y)
|
||||
>>> logits = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
>>> labels = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.38596618
|
||||
"""
|
||||
|
@ -443,6 +458,8 @@ class DiceLoss(_Loss):
|
|||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, label):
|
||||
_check_input_type('logits', logits, Tensor, self.cls_name)
|
||||
_check_input_type('labels', label, Tensor, self.cls_name)
|
||||
_check_shape(logits.shape, label.shape)
|
||||
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
|
||||
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
|
||||
|
@ -487,10 +504,10 @@ class MultiClassDiceLoss(_Loss):
|
|||
Default: 'softmax'. Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh','Sigmoid']
|
||||
|
||||
Inputs:
|
||||
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). The y_pred dimension should be greater than 1. The data
|
||||
- **logits** (Tensor) - Tensor of shape (N, C, ...). The logits dimension should be greater than 1. The data
|
||||
type must be float16 or float32.
|
||||
- **y** (Tensor) - Tensor of shape (N, C, ...). The y dimension should be greater than 1. The data type must be
|
||||
loat16 or float32.
|
||||
- **labels** (Tensor) - Tensor of shape (N, C, ...). The labels dimension should be greater than 1.
|
||||
The data type must be loat16 or float32.
|
||||
|
||||
Outputs:
|
||||
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
|
||||
|
@ -498,8 +515,8 @@ class MultiClassDiceLoss(_Loss):
|
|||
Raises:
|
||||
ValueError: If the shapes are different.
|
||||
TypeError: If the type of inputs are not Tensor.
|
||||
ValueError: If the dimension of y or y_pred is less than 2.
|
||||
ValueError: If the weight shape[0] is not equal to y.shape[1].
|
||||
ValueError: If the dimension of `logits` or `labels` is less than 2.
|
||||
ValueError: If the weight shape[0] is not equal to labels.shape[1].
|
||||
ValueError: If weight is a tensor, but the dimension is not 2.
|
||||
|
||||
Supported Platforms:
|
||||
|
@ -507,9 +524,9 @@ class MultiClassDiceLoss(_Loss):
|
|||
|
||||
Examples:
|
||||
>>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
|
||||
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
>>> output = loss(y_pred, y)
|
||||
>>> logits = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
>>> labels = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.3283009
|
||||
"""
|
||||
|
@ -532,6 +549,8 @@ class MultiClassDiceLoss(_Loss):
|
|||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, label):
|
||||
_check_input_type('logits', logits, Tensor, self.cls_name)
|
||||
_check_input_type('labels', label, Tensor, self.cls_name)
|
||||
_check_shape(logits.shape, label.shape)
|
||||
_check_ndim_multi(logits.ndim, label.ndim)
|
||||
total_loss = 0
|
||||
|
@ -652,6 +671,10 @@ class SampledSoftmaxLoss(_Loss):
|
|||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, weights, biases, labels, inputs):
|
||||
_check_input_type('weights', weights, Tensor, self.cls_name)
|
||||
_check_input_type('biases', biases, Tensor, self.cls_name)
|
||||
_check_input_type('labels', labels, Tensor, self.cls_name)
|
||||
_check_input_type('inputs', inputs, Tensor, self.cls_name)
|
||||
_check_label_dtype(self.dtype(labels), self.cls_name)
|
||||
|
||||
logits, labels = self._compute_sampled_logits(
|
||||
|
@ -800,17 +823,17 @@ class BCELoss(_Loss):
|
|||
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input Tensor. The data type must be float16 or float32.
|
||||
- **labels** (Tensor) - The label Tensor which has same shape and data type as `inputs`.
|
||||
- **logits** (Tensor) - The input Tensor. The data type must be float16 or float32.
|
||||
- **labels** (Tensor) - The label Tensor which has same shape and data type as `logits`.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `logits`.
|
||||
Otherwise, the output is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `inputs`, `labels` or `weight` (if given) is neither float16 not float32.
|
||||
TypeError: If dtype of `logits`, `labels` or `weight` (if given) is neither float16 not float32.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
ValueError: If shape of `inputs` is not the same as `labels` or `weight` (if given).
|
||||
ValueError: If shape of `logits` is not the same as `labels` or `weight` (if given).
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -818,9 +841,9 @@ class BCELoss(_Loss):
|
|||
Examples:
|
||||
>>> weight = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 3.3, 2.2]]), mindspore.float32)
|
||||
>>> loss = nn.BCELoss(weight=weight, reduction='mean')
|
||||
>>> inputs = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]), mindspore.float32)
|
||||
>>> logits = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([[0, 1, 0], [0, 0, 1]]), mindspore.float32)
|
||||
>>> output = loss(inputs, labels)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
1.8952923
|
||||
"""
|
||||
|
@ -835,6 +858,8 @@ class BCELoss(_Loss):
|
|||
self.ones = P.OnesLike()
|
||||
|
||||
def construct(self, inputs, labels):
|
||||
_check_input_type('logits', inputs, Tensor, self.cls_name)
|
||||
_check_input_type('labels', labels, Tensor, self.cls_name)
|
||||
if self.weight_one:
|
||||
weight = self.ones(inputs)
|
||||
else:
|
||||
|
@ -866,14 +891,14 @@ class CosineEmbeddingLoss(_Loss):
|
|||
"none", "mean", and "sum", meaning no reduction, reduce mean and sum on output, respectively. Default "mean".
|
||||
|
||||
Inputs:
|
||||
- **input_x1** (Tensor) - Input tensor.
|
||||
- **input_x2** (Tensor) - Its shape and data type must be the same as `input_x1`'s shape and data type.
|
||||
- **y** (Tensor) - Contains value 1 or -1. Suppose the shape of `input_x1` is
|
||||
:math:`(x_1, x_2, x_3,..., x_R)`, then the shape of `target` must be :math:`(x_1, x_3, x_4, ..., x_R)`.
|
||||
- **logits_x1** (Tensor) - Input tensor.
|
||||
- **logits_x2** (Tensor) - Its shape and data type must be the same as `logits_x1`'s shape and data type.
|
||||
- **labels** (Tensor) - Contains value 1 or -1. Suppose the shape of `logits_x1` is
|
||||
:math:`(x_1, x_2, x_3,..., x_R)`, then the shape of `labels` must be :math:`(x_1, x_3, x_4, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
- **loss** (Tensor) - If `reduction` is "none", its shape is the same as `y`'s shape, otherwise a scalar value
|
||||
will be returned.
|
||||
- **loss** (Tensor) - If `reduction` is "none", its shape is the same as `labels`'s shape,
|
||||
otherwise a scalar value will be returned.
|
||||
|
||||
Raises:
|
||||
TypeError: If `margin` is not a float.
|
||||
|
@ -884,11 +909,11 @@ class CosineEmbeddingLoss(_Loss):
|
|||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]), mindspore.float32)
|
||||
>>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
|
||||
>>> y = Tensor(np.array([1, -1]), mindspore.int32)
|
||||
>>> logits_x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]), mindspore.float32)
|
||||
>>> logits_x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([1, -1]), mindspore.int32)
|
||||
>>> cosine_embedding_loss = nn.CosineEmbeddingLoss()
|
||||
>>> output = cosine_embedding_loss(x1, x2, y)
|
||||
>>> output = cosine_embedding_loss(logits_x1, logits_x2, labels)
|
||||
>>> print(output)
|
||||
0.0003426075
|
||||
"""
|
||||
|
@ -900,6 +925,9 @@ class CosineEmbeddingLoss(_Loss):
|
|||
self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name)
|
||||
|
||||
def construct(self, x1, x2, y):
|
||||
_check_input_type('logits_x1', x1, Tensor, self.cls_name)
|
||||
_check_input_type('logits_x2', x2, Tensor, self.cls_name)
|
||||
_check_input_type('labels', y, Tensor, self.cls_name)
|
||||
F.same_type_shape(x1, x2)
|
||||
_check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
|
||||
# if target > 0, 1-cosine(x1, x2)
|
||||
|
@ -953,27 +981,27 @@ class BCEWithLogitsLoss(_Loss):
|
|||
data type must be float16 or float32. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **predict** (Tensor) - Input logits. The data type must be float16 or float32.
|
||||
- **target** (Tensor) - Ground truth label. Has the same data type and shape with `predict`.
|
||||
- **logits** (Tensor) - Input logits. The data type must be float16 or float32.
|
||||
- **labels** (Tensor) - Ground truth label. Has the same data type and shape with `logits`.
|
||||
|
||||
Outputs:
|
||||
Scalar. If reduction is 'none', it's a tensor with the same shape and type as input `predict`.
|
||||
Scalar. If reduction is 'none', it's a tensor with the same shape and type as input `logits`.
|
||||
|
||||
Raises:
|
||||
TypeError: If data type of `predict` or `target` is neither float16 nor float32.
|
||||
TypeError: If data type of `logits` or `labels` is neither float16 nor float32.
|
||||
TypeError: If `weight` or `pos_weight` is Parameter.
|
||||
TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
|
||||
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `predict`.
|
||||
ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `logits`.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
|
||||
>>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
|
||||
>>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
|
||||
>>> labels = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
|
||||
>>> loss = nn.BCEWithLogitsLoss()
|
||||
>>> output = loss(predict, target)
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
0.3463612
|
||||
"""
|
||||
|
@ -990,6 +1018,8 @@ class BCEWithLogitsLoss(_Loss):
|
|||
self.ones = P.OnesLike()
|
||||
|
||||
def construct(self, predict, target):
|
||||
_check_input_type('logits', predict, Tensor, self.cls_name)
|
||||
_check_input_type('labels', target, Tensor, self.cls_name)
|
||||
ones_input = self.ones(predict)
|
||||
if self.weight is not None:
|
||||
weight = self.weight
|
||||
|
@ -1050,32 +1080,32 @@ class FocalLoss(_Loss):
|
|||
If "none", do not perform reduction. Default: "mean".
|
||||
|
||||
Inputs:
|
||||
- **predict** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). Where C is the number
|
||||
- **logits** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). Where C is the number
|
||||
of classes. Its value is greater than 1. If the shape is (B, C, H, W) or (B, C, H), the H or product of H
|
||||
and W should be the same as target.
|
||||
- **target** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). The value of C is 1 or
|
||||
and W should be the same as labels.
|
||||
- **labels** (Tensor) - Tensor of shape should be (B, C) or (B, C, H) or (B, C, H, W). The value of C is 1 or
|
||||
it needs to be the same as predict's C. If C is not 1, the shape of target should be the same as that of
|
||||
predict, where C is the number of classes. If the shape is (B, C, H, W) or (B, C, H), the H or product of H
|
||||
and W should be the same as predict.
|
||||
and W should be the same as logits.
|
||||
|
||||
Outputs:
|
||||
Tensor, it's a tensor with the same shape and type as input `predict`.
|
||||
Tensor, it's a tensor with the same shape and type as input `logits`.
|
||||
|
||||
Raises:
|
||||
TypeError: If the data type of ``gamma`` is not float.
|
||||
TypeError: If ``weight`` is not a Tensor.
|
||||
ValueError: If ``target`` dim different from ``predict``.
|
||||
ValueError: If ``target`` channel is not 1 and ``target`` shape is different from ``predict``.
|
||||
ValueError: If ``labels`` dim different from ``logits``.
|
||||
ValueError: If ``labels`` channel is not 1 and ``labels`` shape is different from ``logits``.
|
||||
ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Example:
|
||||
>>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
|
||||
>>> target = Tensor([[1], [1], [0]], mstype.int32)
|
||||
>>> logits = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
|
||||
>>> labels = Tensor([[1], [1], [0]], mstype.int32)
|
||||
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
|
||||
>>> output = focalloss(predict, target)
|
||||
>>> output = focalloss(logits, labels)
|
||||
>>> print(output)
|
||||
0.12516622
|
||||
"""
|
||||
|
@ -1098,6 +1128,8 @@ class FocalLoss(_Loss):
|
|||
self.logsoftmax = nn.LogSoftmax(1)
|
||||
|
||||
def construct(self, predict, target):
|
||||
_check_input_type('logits', predict, Tensor, self.cls_name)
|
||||
_check_input_type('labels', target, Tensor, self.cls_name)
|
||||
targets = target
|
||||
_check_ndim(predict.ndim, targets.ndim)
|
||||
_check_channel_and_shape(predict.shape[1], targets.shape[1])
|
||||
|
|
Loading…
Reference in New Issue