forked from mindspore-Ecosystem/mindspore
!46254 fix doc of train and rewrite
Merge pull request !46254 from 于振华/code_docs_fix_model_1130
This commit is contained in:
commit
e28897656b
|
@ -9,10 +9,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
参数:
|
||||
- **handler** (SymbolTreeImpl) - SymbolTree内部实现实例。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - `network` 不是Cell对象。
|
||||
- **RuntimeError** - `network` 中包含不支持解析和优化的ast节点类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.after(node: Node)
|
||||
|
||||
获取插入位置,位置为 `node` 之后。
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
- **metrics** (Union[dict, set]) - 用于模型评估的一组评价函数。例如:{'accuracy', 'recall'}。默认值:None。
|
||||
- **eval_network** (Cell) - 用于评估的神经网络。未定义情况下,`Model` 会使用 `network` 和 `loss_fn` 封装一个 `eval_network` 。默认值:None。
|
||||
- **eval_indexes** (list) - 在定义 `eval_network` 的情况下使用。如果 `eval_indexes` 为默认值None,`Model` 会将 `eval_network` 的所有输出传给 `metrics` 。如果配置 `eval_indexes` ,必须包含三个元素,分别为损失值、预测值和标签在 `eval_network` 输出中的位置,此时,损失值将传给损失评价函数,预测值和标签将传给其他评价函数。推荐使用评价函数的 `mindspore.nn.Metric.set_indexes` 代替 `eval_indexes` 。默认值:None。
|
||||
- **amp_level** (str) - `mindspore.build_train_network` 的可选参数 `level` , `level` 为混合精度等级,该参数支持["O0", "O2", "O3", "auto"]。默认值:"O0"。
|
||||
- **amp_level** (str) - `mindspore.build_train_network` 的可选参数 `level` , `level` 为混合精度等级,该参数支持["O0", "O1", "O2", "O3", "auto"]。默认值:"O0"。
|
||||
|
||||
- "O0": 不变化。
|
||||
- "O1": 将白名单中的算子转为float16,剩余算子保持float32。
|
||||
- "O2": 将网络精度转为float16,BatchNorm保持float32精度,使用动态调整损失缩放系数(loss scale)的策略。
|
||||
- "O3": 将网络精度(包括BatchNorm)转为float16,不使用损失缩放策略。
|
||||
- auto: 为不同处理器设置专家推荐的混合精度等级,如在GPU上设为"O2",在Ascend上设为"O3"。该设置方式可能在部分场景下不适用,建议用户根据具体的网络模型自定义设置 `amp_level` 。
|
||||
|
|
|
@ -32,11 +32,7 @@ class SymbolTree:
|
|||
A `SymbolTree` usually corresponding to forward method of a network.
|
||||
|
||||
Args:
|
||||
network (Cell): Network to be rewritten. Only support `Cell`-type-network now.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `network` is not a Cell.
|
||||
RuntimeError: If there is any unsupported ast node type while parsing or optimizing.
|
||||
handler (SymbolTreeImpl): SymbolTree internal implementation instance.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: SymbolTreeImpl):
|
||||
|
|
Loading…
Reference in New Issue