!46254 fix doc of train and rewrite

Merge pull request !46254 from 于振华/code_docs_fix_model_1130
This commit is contained in:
i-robot 2022-11-30 07:31:09 +00:00 committed by Gitee
commit e28897656b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 3 additions and 10 deletions

View File

@ -9,10 +9,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
参数:
- **handler** (SymbolTreeImpl) - SymbolTree内部实现实例。
异常:
- **RuntimeError** - `network` 不是Cell对象。
- **RuntimeError** - `network` 中包含不支持解析和优化的ast节点类型。
.. py:method:: mindspore.rewrite.SymbolTree.after(node: Node)
获取插入位置,位置为 `node` 之后。

View File

@ -16,9 +16,10 @@
- **metrics** (Union[dict, set]) - 用于模型评估的一组评价函数。例如:{'accuracy', 'recall'}。默认值None。
- **eval_network** (Cell) - 用于评估的神经网络。未定义情况下,`Model` 会使用 `network``loss_fn` 封装一个 `eval_network` 。默认值None。
- **eval_indexes** (list) - 在定义 `eval_network` 的情况下使用。如果 `eval_indexes` 为默认值None`Model` 会将 `eval_network` 的所有输出传给 `metrics` 。如果配置 `eval_indexes` ,必须包含三个元素,分别为损失值、预测值和标签在 `eval_network` 输出中的位置,此时,损失值将传给损失评价函数,预测值和标签将传给其他评价函数。推荐使用评价函数的 `mindspore.nn.Metric.set_indexes` 代替 `eval_indexes` 。默认值None。
- **amp_level** (str) - `mindspore.build_train_network` 的可选参数 `level` `level` 为混合精度等级,该参数支持["O0", "O2", "O3", "auto"]。默认值:"O0"。
- **amp_level** (str) - `mindspore.build_train_network` 的可选参数 `level` `level` 为混合精度等级,该参数支持["O0", "O1", "O2", "O3", "auto"]。默认值:"O0"。
- "O0": 不变化。
- "O1": 将白名单中的算子转为float16剩余算子保持float32。
- "O2": 将网络精度转为float16BatchNorm保持float32精度使用动态调整损失缩放系数loss scale的策略。
- "O3": 将网络精度包括BatchNorm转为float16不使用损失缩放策略。
- auto: 为不同处理器设置专家推荐的混合精度等级如在GPU上设为"O2"在Ascend上设为"O3"。该设置方式可能在部分场景下不适用,建议用户根据具体的网络模型自定义设置 `amp_level`

View File

@ -32,11 +32,7 @@ class SymbolTree:
A `SymbolTree` usually corresponding to forward method of a network.
Args:
network (Cell): Network to be rewritten. Only support `Cell`-type-network now.
Raises:
RuntimeError: If `network` is not a Cell.
RuntimeError: If there is any unsupported ast node type while parsing or optimizing.
handler (SymbolTreeImpl): SymbolTree internal implementation instance.
"""
def __init__(self, handler: SymbolTreeImpl):