modify marginrankingloss errors

This commit is contained in:
fengyihang 2023-01-05 20:28:48 +08:00
parent 7ccb517428
commit 675963844f
1 changed files with 14 additions and 4 deletions

View File

@ -21,7 +21,6 @@ import numpy as np
import mindspore.ops as ops
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import nn_ops as NN_OPS
from mindspore.ops.operations import image_ops as IMG
import mindspore.common.dtype as mstype
@ -3470,6 +3469,17 @@ def _get_loss(x, reduction, cls_name, weights=1.0):
return x
def _check_type_and_shape_same(param_name1, input_data1, param_name2, input_data2, cls_name):
"""check input1 and input2 type and shape same"""
if input_data1.dtype != input_data2.dtype:
raise TypeError(f'For {cls_name}, the {param_name1} dtype should be equal to {param_name2} dtype, '
f'but got {param_name1} dtype:{input_data1.dtype}, {param_name2} dtype:{input_data2.dtype}.')
if input_data1.shape != input_data2.shape:
raise ValueError(f'For {cls_name}, the {param_name1} shape should be equal to {param_name2} shape, '
f'but got {param_name1} shape:{input_data1.shape}, {param_name2} shape:{input_data2.shape}.')
return 0
def margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean'):
"""
MarginRankingLoss creates a criterion that measures the loss.
@ -3481,9 +3491,9 @@ def margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean'):
_check_is_tensor('input2', input2, "margin_ranking_loss")
_check_is_tensor('target', target, "margin_ranking_loss")
maximum = P.Maximum()
inner.same_type_shape_(input1, input2)
inner.same_type_shape_(target, input1)
x = maximum(0, -target * (input1 - input2) + margin)
_check_type_and_shape_same('input1', input1, 'input2', input2, 'margin_ranking_loss')
_check_type_and_shape_same('target', target, 'input1', input1, 'margin_ranking_loss')
x = maximum(-target * (input1 - input2) + margin, 0)
return _get_loss(x, reduction, "margin_ranking_loss")