!49099 [bugfix] update hinge embedding loss

Merge pull request !49099 from shaojunsong/update_hel
This commit is contained in:
i-robot 2023-02-27 07:38:04 +00:00 committed by Gitee
commit 2a16f3992c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 16 additions and 19 deletions

View File

@ -35,8 +35,8 @@ mindspore.nn.HingeEmbeddingLoss
Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。 Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。
异常: 异常:
- **TypeError** - `logits` 不是数据类型为float的Tensor。 - **TypeError** - `logits` 不是Tensor。
- **TypeError** - `labels` 不是数据类型为float的Tensor。 - **TypeError** - `labels` 不是Tensor。
- **TypeError** - `margin` 不是float或int。 - **TypeError** - `margin` 不是float或int。
- **ValueError** - `labels``logits` shape不一致。 - **ValueError** - `labels``logits` shape不一致且不能广播
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。 - **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。

View File

@ -33,8 +33,8 @@ mindspore.ops.hinge_embedding_loss
Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。 Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。
异常: 异常:
- **TypeError** - `inputs` 不是数据类型为float的Tensor。 - **TypeError** - `inputs` 不是Tensor。
- **TypeError** - `targets` 不是数据类型为float的Tensor。 - **TypeError** - `targets` 不是Tensor。
- **TypeError** - `margin` 不是float或者int。 - **TypeError** - `margin` 不是float或者int。
- **ValueError** - `inputs``targets` shape不一致。 - **ValueError** - `inputs``targets` shape不一致且不能广播
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。 - **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。

View File

@ -2509,10 +2509,10 @@ class HingeEmbeddingLoss(LossBase):
Tensor or Tensor scalar, the computed loss depending on `reduction`. Tensor or Tensor scalar, the computed loss depending on `reduction`.
Raises: Raises:
TypeError: If `logits` is not a Tensor of floats. TypeError: If `logits` is not a Tensor.
TypeError: If `labels` is not a Tensor of floats. TypeError: If `labels` is not a Tensor.
TypeError: If `margin` is not a float or int. TypeError: If `margin` is not a float or int.
ValueError: If `labels` does not have the same shape as `logits`. ValueError: If `labels` does not have or could not broadcast to the same shape as `logits`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms: Supported Platforms:

View File

@ -4153,10 +4153,10 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
Tensor or Tensor scalar, the computed loss depending on `reduction`. Tensor or Tensor scalar, the computed loss depending on `reduction`.
Raises: Raises:
TypeError: If `inputs` is not a Tensor of floats. TypeError: If `inputs` is not a Tensor.
TypeError: If `targets` is not a Tensor of floats. TypeError: If `targets` is not a Tensor.
TypeError: If `margin` is not a float or int. TypeError: If `margin` is not a float or int.
ValueError: If `targets` does not have the same shape as `inputs`. ValueError: If `targets` does not have or could not broadcast to the same shape as `inputs`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms: Supported Platforms:
@ -4180,14 +4180,11 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
if reduction not in ['none', 'mean', 'sum']: if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f"For 'HingeEmbeddingLoss', 'reduction' must be one of 'none', 'mean', 'sum'," raise ValueError(f"For 'HingeEmbeddingLoss', 'reduction' must be one of 'none', 'mean', 'sum',"
f"but got {reduction}.") f"but got {reduction}.")
if not isinstance(inputs, Tensor):
raise TypeError(f"For 'HingeEmbeddingLoss', the first input must be a Tensor, but got {type(inputs)}.")
if not isinstance(targets, Tensor):
raise TypeError(f"For 'HingeEmbeddingLoss', the second input must be a Tensor, but got {type(targets)}.")
inputs_dtype = inputs.dtype inputs_dtype = inputs.dtype
targets_dtype = targets.dtype
if inputs_dtype not in mstype.float_type:
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the first input must be float, but got "
f"{inputs_dtype}.")
if targets_dtype not in mstype.float_type:
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the second input must be float, but got "
f"{targets_dtype}.")
min_val = Tensor(0, inputs_dtype) min_val = Tensor(0, inputs_dtype)
pos_index = targets > 0 pos_index = targets > 0
neg_index = targets < 0 neg_index = targets < 0