fix doc
This commit is contained in:
parent
57045d2248
commit
b40fae3974
|
@ -31,11 +31,7 @@ mindspore.ops.kl_div
|
|||
参数:
|
||||
- **logits** (Tensor) - 数据类型支持float16、float32或float64。
|
||||
- **labels** (Tensor) - 标签Tensor,与 `logits` 的shape和数据类型相同。
|
||||
- **reduction** (str) - 指定输出结果的计算方式。默认值: "mean"。
|
||||
|
||||
- 在Ascend平台上, `reduction` 的可选值为"batchmean"、"none"或"sum"。
|
||||
- 在GPU平台上, `reduction` 的可选值为"mean"、"batchmean"、"none"或"sum"。
|
||||
- 在CPU平台上, `reduction` 的可选值为"mean"、"batchmean"、"none"或"sum"。
|
||||
- **reduction** (str) - 指定输出结果的计算方式,可选值为"mean"、"batchmean"、"none"或"sum"。默认值: "mean"。
|
||||
|
||||
返回:
|
||||
Tensor或标量。如果 `reduction` 为 'none' ,则输出为Tensor且与 `logits` 的shape相同。否则为标量。
|
||||
|
|
|
@ -561,7 +561,7 @@ def kl_div(logits, labels, reduction='mean'):
|
|||
|
||||
Returns:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `logits`.
|
||||
Otherwise it is a scalar.
|
||||
Otherwise, it is a scalar.
|
||||
|
||||
Raises:
|
||||
TypeError: If `reduction` is not a str.
|
||||
|
@ -569,29 +569,33 @@ def kl_div(logits, labels, reduction='mean'):
|
|||
TypeError: If dtype of `logits` or `labels` is not float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Note:
|
||||
Currently it does not support float64 input on `Ascend`.
|
||||
It behaves the same as the mathematical definition only when `reduction` is set to `batchmean`.
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
...
|
||||
... def construct(self, logits, labels):
|
||||
... result = mindspore.ops.functional.kl_div(logits, labels, 'mean')
|
||||
... return result
|
||||
...
|
||||
>>> net = Net()
|
||||
>>> logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([0., 1., 0.]), mindspore.float32)
|
||||
>>> output = net(logits, labels)
|
||||
>>> output = mindspore.ops.kl_div(logits, labels, 'mean')
|
||||
>>> print(output)
|
||||
-0.23333333
|
||||
"""
|
||||
if reduction == 'batchmean':
|
||||
kl_div_sum = P.KLDivLoss(reduction='sum')(logits, labels)
|
||||
batch_size = logits.shape[0]
|
||||
shape = P.TensorShape()(logits)
|
||||
batch_size = shape[0]
|
||||
return kl_div_sum / batch_size
|
||||
|
||||
if reduction == 'mean':
|
||||
kl_div_sum = P.KLDivLoss(reduction='sum')(logits, labels)
|
||||
shape = P.TensorShape()(logits)
|
||||
total_size = 1
|
||||
for dim in shape:
|
||||
total_size = total_size * dim
|
||||
return kl_div_sum / total_size
|
||||
|
||||
return P.KLDivLoss(reduction=reduction)(logits, labels)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue