!28488 fix chinese api comments

Merge pull request !28488 from liutongtong9/code_docs_chi_api
This commit is contained in:
i-robot 2022-01-12 03:19:26 +00:00 committed by Gitee
commit b8e1b94eee
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
25 changed files with 129 additions and 99 deletions

View File

@ -3,7 +3,7 @@
.. py:class:: mindspore.Parameter(default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True)
通常表示网络的参数( `Parameter``Tensor` 的子类)
`Parameter``Tensor`的子类当它们被绑定为Cell的属性时会自动添加到其参数列表中并且可以通过Cell的某些方法获取例如 `cell.get_parameters()`
.. note::
在"semi_auto_parallel"和"auto_parallel"的并行模式下,如果使用 `Initializer` 模块初始化参数,参数的类型将为 `Tensor` :class:`mindspore.ops.AllGather``Tensor` 仅保存张量的形状和类型信息,而不占用内存来保存实际数据。并行场景下存在参数的形状发生变化的情况,用户可以调用 `Parameter``init_data` 方法得到原始数据。如果网络中存在需要部分输入为 `Parameter` 的算子,则不允许这部分输入的 `Parameter` 进行转换。如果在 `Cell` 里初始化一个 `Parameter` 作为 `Cell` 的属性时建议使用默认值None否则 `Parameter``name` 可能与预期不一致。
@ -91,7 +91,7 @@
- **opt_shard_group** (str) - 该参数进行优化器切分时的group。
- **set_sliced** (bool) - 参数初始化时被设定为分片则为True。默认值False。
**返回:**
初始化数据后的 `Parameter` 。如果当前 `Parameter` 已初始化,则更新 `Parameter` 数据。
@ -185,4 +185,4 @@
.. py:method:: unique
:property:
表示参数是否唯一。
表示参数是否唯一。

View File

@ -3,20 +3,25 @@ mindspore.ParameterTuple
.. py:class:: mindspore.ParameterTuple(iterable)
参数元组的类。
继承于tuple用于管理多个Parameter。
.. note::
该类把网络参数存储到参数元组集合中。
.. py:method:: clone(prefix, init='same')
按元素克隆 `ParameterTuple` 中的数值,以生成新的 `ParameterTuple`
逐个对ParameterTuple中的Parameter进行克隆生成新的ParameterTuple
**参数:**
- **prefix** (str) - 参数的命名空间。
- **init** (Union[Tensor, str, numbers.Number]) - 初始化参数的shape和dtype。 `init` 的定义与 `Parameter` API中的定义相同。默认值'same'。
- **prefix** (str) - Parameter的namespace此前缀将会被添加到Parametertuple中的Parameter的name属性中。
- **init** (Union[Tensor, str, numbers.Number]) - 对Parametertuple中Parameter的shape和类型进行克隆并根据传入的`init`设置数值。默认值:'same'。
如果 `init``Tensor` 则新参数的数值与该Tensor相同
如果 `init``numbers.Number` ,则设置新参数的数值为该值;
如果 `init``str` ,则按照 `Initializer` 模块中对应的同名的初始化方法进行数值设定;
如果 `init` 是'same'则新参数的数值与原Parameter相同。
**返回:**
新的参数元组。
新的参数元组。

View File

@ -3,10 +3,9 @@ mindspore.nn.Accuracy
.. py:class:: mindspore.nn.Accuracy(eval_type='classification')
计算'classification'单标签数据分类和'multilabel'标签数据分类的正确率
计算数据分类的正确率,包括二分类和多分类。
此类创建两个局部变量,预测正确的样本数和总样本数,用于计算预测值 `y_pred` 和真实标签 `y` 的匹配频率。
此频率最终作为正确率返回:是一个将预测正确的数目除以总数的幂等操作。
此类创建两个局部变量,预测正确的样本数和总样本数,用于计算 `y_pred``y` 的匹配率此匹配率即为accuracy。
.. math::
\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
@ -42,7 +41,7 @@ mindspore.nn.Accuracy
**返回:**
Float计算的结果。
np.float64计算的正确率结果。
**异常:**
@ -55,13 +54,14 @@ mindspore.nn.Accuracy
**参数:**
- **inputs** - 预测值 `y_pred` 和真实标签 `y` `y_pred``y` 支持Tensor、list或numpy.ndarray类型。
- **inputs** - 预测值 `y_pred` 和真实标签 `y` `y_pred` `y` 支持Tensor、list或numpy.ndarray类型。
对于'classification'情况,`y_pred` 在大多数情况下由范围 :math:`[0, 1]` 中的浮点数组成shape为 :math:`(N, C)` ,其中 :math:`N` 是样本数, :math:`C` 是类别数。
对于'classification'情况, `y_pred` 在大多数情况下由范围 :math:`[0, 1]` 中的浮点数组成shape为 :math:`(N, C)` ,其中 :math:`N` 是样本数, :math:`C` 是类别数。
`y` 由整数值组成如果是one_hot编码格式shape是 :math:`(N,C)` 如果是类别索引shape是 :math:`(N,)`
对于'multilabel'情况,`y_pred``y` 只能是值为0或1的one-hot编码格式其中值为1的索引表示正类别。 `y_pred``y` 的shape都是 :math:`(N,C)`
对于'multilabel'情况, `y_pred``y` 只能是值为0或1的one-hot编码格式其中值为1的索引表示正类别。 `y_pred``y` 的shape都是 :math:`(N,C)`
**异常:**
- **ValueError** - inputs的数量不等于2。
- **ValueError** - 当前输入的 `y_pred` 和历史 `y_pred` 类别数不匹配。

View File

@ -3,16 +3,13 @@
.. py:class:: mindspore.nn.Cell(auto_prefix=True, flags=None)
所有神经网络的基类。
MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。
一个 `Cell` 可以是单一的神经网络单元,如 :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.ReLU`, :class:`mindspore.nn.BatchNorm` 等,也可以是组成网络的 `Cell` 的结合体。
.. note::
一般情况下,自动微分 (AutoDiff) 算法会自动调用梯度函数,但是如果使用反向传播方法 (bprop method),梯度函数将会被反向传播方法代替。反向传播函数会接收一个包含损失对输出的梯度张量 `dout` 和一个包含前向传播结果的张量 `out` 。反向传播过程需要计算损失对输入的梯度,损失对参数变量的梯度目前暂不支持。反向传播函数必须包含自身参数。
`mindspore.nn` 中神经网络层也是Cell的子类:class:`mindspore.nn.Conv2d`:class:`mindspore.nn.ReLU`:class:`mindspore.nn.BatchNorm` 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。
**参数:**
- **auto_prefix** (Cell) 递归地生成作用域。默认值True。
- **auto_prefix** (bool) 是否自动为Cell及其子Cell生成NameSpace。`auto_prefix` 的设置影响网络参数的命名如果设置为True则自动给网络参数的名称添加前缀否则不添加前缀。默认值True。
- **flags** (dict) - Cell的配置信息目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值None。
**支持平台:**
@ -83,7 +80,7 @@
**参数:**
- **param** (Parameter) Parameter类型需要被转换类型的输入参数。
- **param** (Parameter) 需要被转换类型的输入参数。
**返回:**
@ -124,7 +121,7 @@
.. py:method:: compile(*inputs)
编译Cell。
编译Cell为计算图输入需与construct中定义的输入一致
**参数:**
@ -132,7 +129,9 @@
.. py:method:: compile_and_run(*inputs)
编译并运行Cell。
编译并运行Cell输入需与construct中定义的输入一致。
不推荐使用该函数建议直接调用Cell实例
**参数:**
@ -156,13 +155,13 @@
.. py:method:: extend_repr()
设置Cell的扩展表示形式
在原有描述基础上扩展Cell的描述
若需要在print时输出个性化的扩展信息请在您的网络中重新实现此方法。
.. py:method:: generate_scope()
为网络中的每个Cell对象生成作用域
为网络中的每个Cell对象生成NameSpace
.. py:method:: get_flags()
@ -174,7 +173,7 @@
.. py:method:: get_parameters(expand=True)
返回一个该Cell中parameter的迭代器。
返回Cell中parameter的迭代器。
**参数:**

View File

@ -3,7 +3,7 @@ mindspore.nn.Fbeta
.. py:class:: mindspore.nn.Fbeta(beta)
计算fbeta评分。
计算Fbeta评分。
Fbeta评分是精度(Precision)和召回率(Recall)的加权平均值。
@ -38,19 +38,24 @@ mindspore.nn.Fbeta
计算fbeta结果。
**参数:**
- **average** (bool) - 是否计算fbeta平均值。默认值False。
**返回:**
numpy.ndarray或numpy.float64计算结果。
numpy.ndarray或numpy.float64计算的Fbeta score结果。
.. py:method:: update(*inputs)
使用预测值 `y_pred` 和真实标签 `y` 更新内部评估结果。
**参数:**
- **inputs** - `y_pred``y``y_pred``y` 支持Tensor、list或numpy.ndarray类型。
通常情况下, `y_pred` 是0到1之间的浮点数列表shape为 :math:`(N, C)` ,其中 :math:`N` 是样本数, :math:`C` 是类别数。
`y` 是整数值如果使用one-hot编码则shape为 :math:`(N,C)` 如果使用类别索引shape是 :math:`(N,)`
**异常:**
- **ValueError** - 当前输入的 `y_pred` 和历史 `y_pred` 类别数不匹配。
- **ValueError** - 预测值和真实值包含的类别不同。

View File

@ -3,7 +3,7 @@ mindspore.nn.Loss
.. py:class:: mindspore.nn.Loss
计算loss的平均值。如果每 :math:`n` 次迭代调用一次 `update` 方法,则评估结果为:
计算loss的平均值。如果每 :math:`n` 次迭代调用一次 `update` 方法,则计算结果为:
.. math::
loss = \frac{\sum_{k=1}^{n}loss_k}{n}

View File

@ -3,14 +3,14 @@ mindspore.nn.MAE
.. py:class:: mindspore.nn.MAE
计算平均绝对误差MAE
计算平均绝对误差MAEMean Absolute Error)。
创建了一个用于测量输入 :math:`x` 和目标 :math:`y` 各元素之间的平均绝对误差MAE的标准
计算输入 :math:`x` 和目标 :math:`y` 各元素之间的平均绝对误差。
.. math::
\text{MAE} = \frac{\sum_{i=1}^n \|y_i - x_i\|}{n}
\text{MAE} = \frac{\sum_{i=1}^n \|y_{pred}_i - y_i\|}{n}
这里, :math:`n` 是bach size。
这里, :math:`n` 是batch size。
**样例:**
@ -36,7 +36,7 @@ mindspore.nn.MAE
**返回:**
numpy.float64计算结果。
numpy.float64计算的MAE的结果。
**异常:**

View File

@ -3,12 +3,12 @@ mindspore.nn.MSE
.. py:class:: mindspore.nn.MSE
测量均方差MSE
测量均方差MSEMean Squared Error)。
创建用于计算输入 :math:`x` 和目标 :math:`y` 中的每个元素的均方差L2范数平方的标准
计算输入 :math:`x` 和目标 :math:`y` 各元素之间的平均平方误差
.. math::
\text{MSE}(x,\ y) = \frac{\sum_{i=1}^n(y_i - x_i)^2}{n}
\text{MSE}(x,\ y) = \frac{\sum_{i=1}^n(y_{pred}_i - y_i)^2}{n}
其中, :math:`n` 为batch size。
@ -34,7 +34,7 @@ mindspore.nn.MSE
**返回:**
numpy.float64计算结果。
numpy.float64计算的MSE的结果。
**异常:**

View File

@ -11,7 +11,7 @@ mindspore.nn.Metric
.. py:method:: clear()
:abstractmethod:
描述了清除内部评估结果的行为
清除内部评估结果。
.. note::
所有子类都必须重写此接口。
@ -19,7 +19,7 @@ mindspore.nn.Metric
.. py:method:: eval()
:abstractmethod:
描述了计算最终评估结果的行为
计算最终评估结果。
.. note::
所有子类都必须重写此接口。
@ -27,7 +27,7 @@ mindspore.nn.Metric
.. py:method:: indexes
:property:
获取当前的 `indexes` 值。默认为None调用 `set_indexes` 可修改 `indexes` 值。
获取当前的 `indexes` 值。默认为None调用 `set_indexes` 方法可修改 `indexes` 值。
.. py:method:: set_indexes(indexes)
@ -47,6 +47,10 @@ mindspore.nn.Metric
:class:`Metric` ,类实例本身。
**异常:**
- **ValueError** - 如果输入的index类型不是list或其元素类型不全为int。
**样例:**
>>> import numpy as np
@ -66,11 +70,11 @@ mindspore.nn.Metric
.. py:method:: update(*inputs)
:abstractmethod:
描述了更新内部评估结果的行为
更新内部评估结果。
.. note::
所有子类都必须重写此接口。
**参数:**
- **inputs** - 可变长度输入参数列表。通常是预测值和对应的真实标签。
- **inputs** - 可变长度输入参数列表。通常是预测值和对应的真实标签。

View File

@ -3,9 +3,9 @@ mindspore.nn.Precision
.. py:class:: mindspore.nn.Precision(eval_type='classification')
计算'classification'单标签数据分类和'multilabel'多标签数据分类的精度
计算数据分类的精度,包括单标签场景和多标签场景
此函数创建两个局部变量 :math:`\text{true_positive}`:math:`\text{false_positive}` 用于计算精度。计算方式:math:`\text{true_positive}` 除以 :math:`\text{true_positive}`:math:`\text{false_positive}` 的和,是一个幂等操作,此值最终作为精度返回。
此函数创建两个局部变量 :math:`\text{true_positive}`:math:`\text{false_positive}` 用于计算精度。计算方式如下:
.. math::
\text{precision} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_positive}}

View File

@ -3,9 +3,9 @@ mindspore.nn.Recall
.. py:class:: mindspore.nn.Recall(eval_type='classification')
计算'classification'单标签数据分类和'multilabel'多标签数据分类的召回率
计算数据分类的召回率,包括单标签场景和多标签场景
recall类创建两个局部变量 :math:`\text{true_positive}`:math:`\text{false_negative}` 用于计算召回率。计算方式为 :math:`\text{true_positive}` 除以 :math:`\text{true_positive}`:math:`\text{false_negative}` 的和,是一个幂等操作,此值最终作为召回返回。
Recall类创建两个局部变量 :math:`\text{true_positive}`:math:`\text{false_negative}` 用于计算召回率。计算方式为
.. math::
\text{recall} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_negative}}

View File

@ -5,12 +5,9 @@ mindspore.nn.TopKCategoricalAccuracy
计算top-k分类正确率。
.. note::
`update` 方法需要接收满足 :math:`(y_{pred}, y)` 格式的输入。如果某些样本具有相同的正确率,则将选择第一个样本。
**参数:**
**k (int)** - 指定要计算的top-k分类正确率
**k (int)** - 计算准确率使用的Top类别数。
**异常:**
@ -19,6 +16,7 @@ mindspore.nn.TopKCategoricalAccuracy
**样例:**
>>> import mindspore
>>> import numpy as np
>>> from mindspore import nn, Tensor
>>>
@ -48,6 +46,9 @@ mindspore.nn.TopKCategoricalAccuracy
使用预测值 `y_pred` 和真实标签 `y` 更新局部变量。
.. note::
`update` 方法需要接收满足 :math:`(y_{pred}, y)` 格式的输入。如果某些样本具有相同的正确率,则将选择第一个样本。
**参数:**
- **inputs** - 输入 `y_pred``y``y_pred``y` 支持Tensor、list或numpy.ndarray类型。

View File

@ -15,6 +15,10 @@ mindspore.nn.get_metric_fn
metric对象metric方法的类实例。
**样例**
**异常**
- **TypeError** - 入参`metric`的类型不是None, dict或set。
**样例:**
>>> from mindspore import nn
>>> metric = nn.get_metric_fn('precision', eval_type='classification')

View File

@ -9,6 +9,7 @@ mindspore.nn.rearrange_inputs
**样例:**
>>> from mindspore.nn import rearrange_inputs
>>> class RearrangeInputsExample:
... def __init__(self):
... self._indexes = None

View File

@ -67,7 +67,8 @@ def init_to_value(init):
class Parameter(Tensor_):
r"""
An object holding weights of cells, after initialized `Parameter` is a subtype of `Tensor`.
`Parameter` is a `Tensor` subclass, when they are assigned as Cell attributes they are automatically added to
the list of its parameters, and will appear e.g. in `cell.get_parameters()` iterator.
Note:
In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
@ -644,7 +645,7 @@ class Parameter(Tensor_):
class ParameterTuple(tuple):
"""
Class for storing tuple of parameters.
Inherited from tuple, ParameterTuple is used to save multiple parameter.
Note:
It is used to store the parameters of the network into the parameter tuple collection.
@ -678,11 +679,17 @@ class ParameterTuple(tuple):
Clone the parameters in ParameterTuple element-wisely to generate a new ParameterTuple.
Args:
prefix (str): Namespace of parameter.
init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameters.
The definition of `init` is the same as in `Parameter` API. If `init` is 'same', the
parameters in the new parameter tuple are the same as those in the original parameter tuple.
Default: 'same'.
prefix (str): Namespace of parameter, the prefix string will be added to the names of parameters
in parametertuple.
init (Union[Tensor, str, numbers.Number]): Clone the shape and dtype of Parameters in ParameterTuple and
set data according to `init`. Default: 'same'.
If `init` is a `Tensor` , set the new Parameter data to the input Tensor.
If `init` is `numbers.Number` , set the new Parameter data to the input number.
If `init` is a `str`, data will be seted according to the initialization method of the same name in
the `Initializer`.
If `init` is 'same', the new Parameter has the same value with the original Parameter.
Returns:
Tuple, the new Parameter tuple.

View File

@ -39,23 +39,18 @@ from ..parallel._tensor import _load_tensor_by_layout
class Cell(Cell_):
"""
Base class for all neural networks.
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
base class.
A 'Cell' could be a single neural network cell, such as :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.ReLU`,
:class:`mindspore.nn.BatchNorm`, etc. or a composition of cells to constructing a network.
Note:
In general, the autograd algorithm will automatically generate the implementation of the gradient function,
but if back-propagation(bprop) method is implemented, the gradient function will be replaced by the bprop.
The bprop implementation will receive a tensor `dout` containing the gradient of the loss w.r.t.
the output, and a tensor `out` containing the forward result. The bprop needs to compute the
gradient of the loss w.r.t. the inputs, gradient of the loss w.r.t. Parameter variables are not supported
currently. The bprop method must contain the self parameter.
Layers in `mindspore.nn` are also the subclass of Cell, such as :class:`mindspore.nn.Conv2d`,
:class:`mindspore.nn.ReLU`, :class:`mindspore.nn.BatchNorm`, etc. Cell will be compiled into a calculation
graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
PYNATIVE_MODE (dynamic graph mode).
Args:
auto_prefix (bool): Recursively generate namespaces. It will affect the name of the parameter in the network.
If set to True, the network parameter name will be prefixed, otherwise it will not.
Default: True.
auto_prefix (bool): Whether to automatically generate NameSpace for Cell and its subcells. It will affect the
name of the parameter in the network. If set to True, the network parameter
name will be prefixed, otherwise it will not. Default: True.
flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
Users can also customize network attributes by this parameter. Default: None.
@ -673,7 +668,7 @@ class Cell(Cell_):
def extend_repr(self):
"""
Sets the extended representation of the Cell.
Expand the description of Cell.
To print customized extended information, re-implement this method in your own cells.
"""
@ -781,7 +776,7 @@ class Cell(Cell_):
def compile(self, *inputs):
"""
Compiles cell.
Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
Args:
inputs (tuple): Inputs of the Cell object.
@ -790,7 +785,9 @@ class Cell(Cell_):
def compile_and_run(self, *inputs):
"""
Compiles and runs cell.
Compile and run Cell, the input must be consistent with the input defined in construct.
Note: It is not recommended to call directly.
Args:
inputs (tuple): Inputs of the Cell object.

View File

@ -136,6 +136,9 @@ def get_metrics(metrics):
Returns:
dict, the key is metric name, the value is class instance of metric method.
Raises:
TypeError: If the type of argument 'metrics' is not None, dict or set.
"""
if metrics is None:
return metrics

View File

@ -21,9 +21,8 @@ class Accuracy(EvaluationBase):
r"""
Calculates the accuracy for classification and multilabel data.
The accuracy class has two local variables, the correct number and the total number of samples, that are used to
compute the frequency with which `y_pred` matches `y`. This frequency is ultimately returned as the accuracy: an
idempotent operation that simply divides the correct number by the total number.
The accuracy class creates two local variables, the correct number and the total number that are used to
compute the frequency with which y_pred matches y. This frequency is the accuracy.
.. math::
\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
@ -80,6 +79,8 @@ class Accuracy(EvaluationBase):
Raises:
ValueError: If the number of the inputs is not 2.
ValueError: class numbers of last input predicted data and current predicted data not match.
"""
if len(inputs) != 2:
raise ValueError("For 'Accuracy.update', it needs 2 inputs (predicted value, true value), "
@ -115,7 +116,7 @@ class Accuracy(EvaluationBase):
Computes the accuracy.
Returns:
Float, the computed result.
np.float64, the computed result.
Raises:
RuntimeError: If the sample size is 0.

View File

@ -25,7 +25,7 @@ class MAE(Metric):
in the input: :math:`x` and the target: :math:`y`.
.. math::
\text{MAE} = \frac{\sum_{i=1}^n \|y_i - x_i\|}{n}
\text{MAE} = \frac{\sum_{i=1}^n \|y_{pred}_i - y_i\|}{n}
where :math:`n` is batch size.
@ -101,7 +101,7 @@ class MSE(Metric):
each element in the prediction and the ground truth: :math:`x` and: :math:`y`.
.. math::
\text{MSE}(x,\ y) = \frac{\sum_{i=1}^n(y_i - x_i)^2}{n}
\text{MSE}(x,\ y) = \frac{\sum_{i=1}^n(y_{pred}_i - y_i)^2}{n}
where :math:`n` is batch size.

View File

@ -21,7 +21,7 @@ from .metric import Metric, rearrange_inputs
class Fbeta(Metric):
r"""
Calculates the fbeta score.
Calculates the Fbeta score.
Fbeta score is a weighted mean of precision and recall.
@ -74,6 +74,10 @@ class Fbeta(Metric):
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. y contains values of integers. The shape is :math:`(N, C)`
if one-hot encoding is used. Shape can also be :math:`(N,)` if category index is used.
Raises:
ValueError: class numbers of last input predicted data and current predicted data not match.
ValueError: If the predicted value and true value contain different classes.
"""
if len(inputs) != 2:
raise ValueError("For 'Fbeta.update', it needs 2 inputs (predicted value, true value), "

View File

@ -167,6 +167,9 @@ class Metric(metaclass=ABCMeta):
Outputs:
:class:`Metric`, its original Class instance.
Raises:
ValueError: If the type of input 'indexes' is not a list or its elements are not all int.
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor

View File

@ -29,7 +29,7 @@ class Perplexity(Metric):
Args:
ignore_label (Union[int, None]): Index of an invalid label to be ignored when counting. If set to `None`,
it will include all entries. Default: None.
it will include all entries. Default: None.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -26,9 +26,7 @@ class Precision(EvaluationBase):
Calculates precision for classification and multilabel data.
The precision function creates two local variables, :math:`\text{true_positive}` and
:math:`\text{false_positive}`, that are used to compute the precision. This value is
ultimately returned as the precision, an idempotent operation that simply divides
:math:`\text{true_positive}` by the sum of :math:`\text{true_positive}` and :math:`\text{false_positive}`.
:math:`\text{false_positive}`, that are used to compute the precision. The calculation formula is:
.. math::
\text{precision} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_positive}}

View File

@ -26,10 +26,7 @@ class Recall(EvaluationBase):
Calculates recall for classification and multilabel data.
The recall class creates two local variables, :math:`\text{true_positive}` and :math:`\text{false_negative}`,
that are used to compute the recall. This value is ultimately returned as the recall, an idempotent operation
that simply divides :math:`\text{true_positive}` by the sum of :math:`\text{true_positive}` and
:math:`\text{false_negative}`.
that are used to compute the recall. The calculation formula is:
.. math::
\text{recall} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_negative}}

View File

@ -21,10 +21,6 @@ class TopKCategoricalAccuracy(Metric):
"""
Calculates the top-k categorical accuracy.
Note:
The method `update` must receive input of the form :math:`(y_{pred}, y)`. If some samples have
the same accuracy, the first sample will be chosen.
Args:
k (int): Specifies the top-k categorical accuracy to compute.
@ -36,6 +32,7 @@ class TopKCategoricalAccuracy(Metric):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import nn, Tensor
>>>
@ -76,6 +73,10 @@ class TopKCategoricalAccuracy(Metric):
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. `y` contains values of integers. The shape is :math:`(N, C)`
if one-hot encoding is used. Shape can also be :math:`(N,)` if category index is used.
Note:
The method `update` must receive input of the form :math:`(y_{pred}, y)`. If some samples have
the same accuracy, the first sample will be chosen.
"""
if len(inputs) != 2:
raise ValueError("For 'TopKCategoricalAccuracy.update', "