update hel

This commit is contained in:
shaojunsong 2023-02-18 18:09:14 +08:00
parent c50aaccce2
commit 6691fa9dbf
4 changed files with 16 additions and 19 deletions

View File

@ -35,8 +35,8 @@ mindspore.nn.HingeEmbeddingLoss
Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。
异常:
- **TypeError** - `logits` 不是数据类型为float的Tensor。
- **TypeError** - `labels` 不是数据类型为float的Tensor。
- **TypeError** - `logits` 不是Tensor。
- **TypeError** - `labels` 不是Tensor。
- **TypeError** - `margin` 不是float或int。
- **ValueError** - `labels``logits` shape不一致。
- **ValueError** - `labels``logits` shape不一致且不能广播
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。

View File

@ -33,8 +33,8 @@ mindspore.ops.hinge_embedding_loss
Tensor或Tensor scalar根据 :math:`reduction` 计算的loss。
异常:
- **TypeError** - `inputs` 不是数据类型为float的Tensor。
- **TypeError** - `targets` 不是数据类型为float的Tensor。
- **TypeError** - `inputs` 不是Tensor。
- **TypeError** - `targets` 不是Tensor。
- **TypeError** - `margin` 不是float或者int。
- **ValueError** - `inputs``targets` shape不一致。
- **ValueError** - `inputs``targets` shape不一致且不能广播
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。

View File

@ -2505,10 +2505,10 @@ class HingeEmbeddingLoss(LossBase):
Tensor or Tensor scalar, the computed loss depending on `reduction`.
Raises:
TypeError: If `logits` is not a Tensor of floats.
TypeError: If `labels` is not a Tensor of floats.
TypeError: If `logits` is not a Tensor.
TypeError: If `labels` is not a Tensor.
TypeError: If `margin` is not a float or int.
ValueError: If `labels` does not have the same shape as `logits`.
ValueError: If `labels` does not have or could not broadcast to the same shape as `logits`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:

View File

@ -4099,10 +4099,10 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
Tensor or Tensor scalar, the computed loss depending on `reduction`.
Raises:
TypeError: If `inputs` is not a Tensor of floats.
TypeError: If `targets` is not a Tensor of floats.
TypeError: If `inputs` is not a Tensor.
TypeError: If `targets` is not a Tensor.
TypeError: If `margin` is not a float or int.
ValueError: If `targets` does not have the same shape as `inputs`.
ValueError: If `targets` does not have or could not broadcast to the same shape as `inputs`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
@ -4126,14 +4126,11 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f"For 'HingeEmbeddingLoss', 'reduction' must be one of 'none', 'mean', 'sum',"
f"but got {reduction}.")
if not isinstance(inputs, Tensor):
raise TypeError(f"For 'HingeEmbeddingLoss', the first input must be a Tensor, but got {type(inputs)}.")
if not isinstance(targets, Tensor):
raise TypeError(f"For 'HingeEmbeddingLoss', the second input must be a Tensor, but got {type(targets)}.")
inputs_dtype = inputs.dtype
targets_dtype = targets.dtype
if inputs_dtype not in mstype.float_type:
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the first input must be float, but got "
f"{inputs_dtype}.")
if targets_dtype not in mstype.float_type:
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the second input must be float, but got "
f"{targets_dtype}.")
min_val = Tensor(0, inputs_dtype)
pos_index = targets > 0
neg_index = targets < 0