fix comments

This commit is contained in:
liutongtong 2022-01-07 19:51:52 +08:00
parent 0a52e1723c
commit 03b856c952
17 changed files with 52 additions and 15 deletions

View File

@ -149,7 +149,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
Check argument integer.
Usage:
- number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
- number = check_number(number, 0, Rel.GE, "number", None)
"""
rel_fn = Rel.get_fns(rel)
prim_name = f' in `{prim_name}`' if prim_name else ''

View File

@ -792,8 +792,8 @@ class _CellGraphExecutor:
def ms_memory_recycle():
"""
Recycle memory used by MindSpore.
When train multi Neural network models in one process, memory used by mindspore is very large,
this is because mindspore cached runtime memory for every model.
When train multi Neural network models in one process, memory used by MindSpore is very large,
this is because MindSpore cached runtime memory for every model.
To recycle these cached memory, users can call this function after training of one model.
"""
if ms_compile_cache:

View File

@ -61,7 +61,7 @@ def set_seed(seed):
>>> import mindspore.ops as ops
>>> from mindspore import Tensor, set_seed, Parameter
>>> from mindspore.common.initializer import initializer
>>>
>>> import mindspore as ms
>>> # Note: (1) Please make sure the code is running in PYNATIVE MODE;
>>> # (2) Because Composite-level ops need parameters to be Tensors, for below examples,
>>> # when using ops.uniform operator, minval and maxval are initialised as:
@ -129,7 +129,7 @@ def set_seed(seed):
>>> # condition 5.
>>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
>>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
>>> # Rerun the program will get the same results:
>>> # Rerun the program will get the different results:
>>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
>>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
>>>

View File

@ -34,6 +34,9 @@ class Accuracy(EvaluationBase):
'multilabel'. 'classification' means the dataset label is single. 'multilabel' means the dataset has multiple
labels. Default: 'classification'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore

View File

@ -29,6 +29,9 @@ class MAE(Metric):
where :math:`n` is batch size.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
@ -102,6 +105,9 @@ class MSE(Metric):
where :math:`n` is batch size.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore

View File

@ -30,7 +30,10 @@ class Fbeta(Metric):
{(1+\beta^2) \cdot true\_positive +\beta^2 \cdot false\_negative + false\_positive}
Args:
beta (Union[float, int]): Beta coefficient in the F measure.
beta (Union[float, int]): Beta coefficient in the F measure. `beta` should be greater than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np

View File

@ -24,6 +24,9 @@ class Loss(Metric):
.. math::
loss = \frac{\sum_{k=1}^{n}loss_k}{n}
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore

View File

@ -58,9 +58,7 @@ class MeanSurfaceDistance(Metric):
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor
>>>
>>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
>>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
>>> metric = nn.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean")

View File

@ -28,6 +28,7 @@ def rearrange_inputs(func):
This decorator is currently applied on the `update` of :class:`mindspore.nn.Metric`.
Examples:
>>> from mindspore.nn import rearrange_inputs
>>> class RearrangeInputsExample:
... def __init__(self):
... self._indexes = None

View File

@ -28,8 +28,8 @@ class Perplexity(Metric):
PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}
Args:
ignore_label (int): Index of an invalid label to be ignored when counting. If set to `None`, it will include all
entries. Default: None.
ignore_label (Union[int, None]): Index of an invalid label to be ignored when counting. If set to `None`,
it will include all entries. Default: None.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -40,7 +40,6 @@ class Perplexity(Metric):
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Perplexity(ignore_label=None)

View File

@ -39,6 +39,9 @@ class Precision(EvaluationBase):
Args:
eval_type (str): 'classification' or 'multilabel' are supported. Default: 'classification'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor

View File

@ -40,6 +40,9 @@ class Recall(EvaluationBase):
eval_type (str): 'classification' or 'multilabel' are supported. Default: 'classification'.
Default: 'classification'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor

View File

@ -32,6 +32,9 @@ class TopKCategoricalAccuracy(Metric):
TypeError: If `k` is not int.
ValueError: If `k` is less than 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor
@ -108,6 +111,9 @@ class Top1CategoricalAccuracy(TopKCategoricalAccuracy):
Calculates the top-1 categorical accuracy. This class is a specialized class for TopKCategoricalAccuracy.
Refer to :class:`TopKCategoricalAccuracy` for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor
@ -131,6 +137,9 @@ class Top5CategoricalAccuracy(TopKCategoricalAccuracy):
Calculates the top-5 categorical accuracy. This class is a specialized class for TopKCategoricalAccuracy.
Refer to :class:`TopKCategoricalAccuracy` for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import nn, Tensor

View File

@ -296,7 +296,7 @@ class Adam(Optimizer):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>>from mindspore import nn, Model
>>> from mindspore import nn, Model
>>>
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
@ -476,7 +476,7 @@ class AdamWeightDecay(Optimizer):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>>from mindspore import nn, Model
>>> from mindspore import nn, Model
>>>
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay

View File

@ -692,12 +692,20 @@ class ParameterUpdate(Cell):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor
>>> network = nn.Dense(3, 4)
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> output = update(weight)
>>> print(output)
[[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]]
"""
def __init__(self, param):

View File

@ -996,6 +996,7 @@ class Model:
Tensor, array(s) of predictions.
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, Tensor
>>>

View File

@ -222,7 +222,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: False
append_dict (dict): Additional information that needs to be saved. The key of dict must be str,
the value of dict must be one of int float and bool. Default: None
the value of dict must be one of int, float or bool. Default: None
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
@ -734,7 +734,7 @@ def _fill_param_into_net(net, parameter_list):
def export(net, *inputs, file_name, file_format='AIR', **kwargs):
"""
Export the mindspore network into an offline model in the specified format.
Export the MindSpore network into an offline model in the specified format.
Note:
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.