forked from mindspore-Ecosystem/mindspore
!47572 modify marginrankingloss errors
Merge pull request !47572 from 冯一航/modify_marginrankingloss_errors
This commit is contained in:
commit
f10f5e722b
|
@ -21,7 +21,6 @@ import numpy as np
|
|||
import mindspore.ops as ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import nn_ops as NN_OPS
|
||||
from mindspore.ops.operations import image_ops as IMG
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -3476,6 +3475,17 @@ def _get_loss(x, reduction, cls_name, weights=1.0):
|
|||
return x
|
||||
|
||||
|
||||
def _check_type_and_shape_same(param_name1, input_data1, param_name2, input_data2, cls_name):
|
||||
"""check input1 and input2 type and shape same"""
|
||||
if input_data1.dtype != input_data2.dtype:
|
||||
raise TypeError(f'For {cls_name}, the {param_name1} dtype should be equal to {param_name2} dtype, '
|
||||
f'but got {param_name1} dtype:{input_data1.dtype}, {param_name2} dtype:{input_data2.dtype}.')
|
||||
if input_data1.shape != input_data2.shape:
|
||||
raise ValueError(f'For {cls_name}, the {param_name1} shape should be equal to {param_name2} shape, '
|
||||
f'but got {param_name1} shape:{input_data1.shape}, {param_name2} shape:{input_data2.shape}.')
|
||||
return 0
|
||||
|
||||
|
||||
def margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean'):
|
||||
"""
|
||||
MarginRankingLoss creates a criterion that measures the loss.
|
||||
|
@ -3487,9 +3497,9 @@ def margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean'):
|
|||
_check_is_tensor('input2', input2, "margin_ranking_loss")
|
||||
_check_is_tensor('target', target, "margin_ranking_loss")
|
||||
maximum = P.Maximum()
|
||||
inner.same_type_shape_(input1, input2)
|
||||
inner.same_type_shape_(target, input1)
|
||||
x = maximum(0, -target * (input1 - input2) + margin)
|
||||
_check_type_and_shape_same('input1', input1, 'input2', input2, 'margin_ranking_loss')
|
||||
_check_type_and_shape_same('target', target, 'input1', input1, 'margin_ranking_loss')
|
||||
x = maximum(-target * (input1 - input2) + margin, 0)
|
||||
return _get_loss(x, reduction, "margin_ranking_loss")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue