!49099 [bugfix] update hinge embedding loss
Merge pull request !49099 from shaojunsong/update_hel
This commit is contained in:
commit
2a16f3992c
|
@ -35,8 +35,8 @@ mindspore.nn.HingeEmbeddingLoss
|
||||||
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `logits` 不是数据类型为float的Tensor。
|
- **TypeError** - `logits` 不是Tensor。
|
||||||
- **TypeError** - `labels` 不是数据类型为float的Tensor。
|
- **TypeError** - `labels` 不是Tensor。
|
||||||
- **TypeError** - `margin` 不是float或int。
|
- **TypeError** - `margin` 不是float或int。
|
||||||
- **ValueError** - `labels` 和 `logits` shape不一致。
|
- **ValueError** - `labels` 和 `logits` shape不一致且不能广播。
|
||||||
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
||||||
|
|
|
@ -33,8 +33,8 @@ mindspore.ops.hinge_embedding_loss
|
||||||
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
- **TypeError** - `inputs` 不是数据类型为float的Tensor。
|
- **TypeError** - `inputs` 不是Tensor。
|
||||||
- **TypeError** - `targets` 不是数据类型为float的Tensor。
|
- **TypeError** - `targets` 不是Tensor。
|
||||||
- **TypeError** - `margin` 不是float或者int。
|
- **TypeError** - `margin` 不是float或者int。
|
||||||
- **ValueError** - `inputs` 和 `targets` shape不一致。
|
- **ValueError** - `inputs` 和 `targets` shape不一致且不能广播。
|
||||||
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
||||||
|
|
|
@ -2509,10 +2509,10 @@ class HingeEmbeddingLoss(LossBase):
|
||||||
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `logits` is not a Tensor of floats.
|
TypeError: If `logits` is not a Tensor.
|
||||||
TypeError: If `labels` is not a Tensor of floats.
|
TypeError: If `labels` is not a Tensor.
|
||||||
TypeError: If `margin` is not a float or int.
|
TypeError: If `margin` is not a float or int.
|
||||||
ValueError: If `labels` does not have the same shape as `logits`.
|
ValueError: If `labels` does not have or could not broadcast to the same shape as `logits`.
|
||||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
|
|
|
@ -4153,10 +4153,10 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
|
||||||
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `inputs` is not a Tensor of floats.
|
TypeError: If `inputs` is not a Tensor.
|
||||||
TypeError: If `targets` is not a Tensor of floats.
|
TypeError: If `targets` is not a Tensor.
|
||||||
TypeError: If `margin` is not a float or int.
|
TypeError: If `margin` is not a float or int.
|
||||||
ValueError: If `targets` does not have the same shape as `inputs`.
|
ValueError: If `targets` does not have or could not broadcast to the same shape as `inputs`.
|
||||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
|
@ -4180,14 +4180,11 @@ def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
|
||||||
if reduction not in ['none', 'mean', 'sum']:
|
if reduction not in ['none', 'mean', 'sum']:
|
||||||
raise ValueError(f"For 'HingeEmbeddingLoss', 'reduction' must be one of 'none', 'mean', 'sum',"
|
raise ValueError(f"For 'HingeEmbeddingLoss', 'reduction' must be one of 'none', 'mean', 'sum',"
|
||||||
f"but got {reduction}.")
|
f"but got {reduction}.")
|
||||||
|
if not isinstance(inputs, Tensor):
|
||||||
|
raise TypeError(f"For 'HingeEmbeddingLoss', the first input must be a Tensor, but got {type(inputs)}.")
|
||||||
|
if not isinstance(targets, Tensor):
|
||||||
|
raise TypeError(f"For 'HingeEmbeddingLoss', the second input must be a Tensor, but got {type(targets)}.")
|
||||||
inputs_dtype = inputs.dtype
|
inputs_dtype = inputs.dtype
|
||||||
targets_dtype = targets.dtype
|
|
||||||
if inputs_dtype not in mstype.float_type:
|
|
||||||
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the first input must be float, but got "
|
|
||||||
f"{inputs_dtype}.")
|
|
||||||
if targets_dtype not in mstype.float_type:
|
|
||||||
raise TypeError(f"For 'HingeEmbeddingLoss', the dtype of the second input must be float, but got "
|
|
||||||
f"{targets_dtype}.")
|
|
||||||
min_val = Tensor(0, inputs_dtype)
|
min_val = Tensor(0, inputs_dtype)
|
||||||
pos_index = targets > 0
|
pos_index = targets > 0
|
||||||
neg_index = targets < 0
|
neg_index = targets < 0
|
||||||
|
|
Loading…
Reference in New Issue