fix rewrite doc and add code examples

This commit is contained in:
yuzhenhua 2022-11-05 10:34:12 +08:00
parent 0b25747876
commit f0c2fdc0cc
6 changed files with 306 additions and 76 deletions

View File

@ -4,48 +4,49 @@ mindspore.rewrite
.. py:class:: mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)
SymbolTree通常对应于网络的forward方法
SymbolTree通常对应于网络的前向计算过程
参数:
- **network** (Cell) - 要重写的网络。现在只支持Cell类型的网络
- **handler** (SymbolTreeImpl) - SymbolTree内部实现实例
异常:
- **RuntimeError** - 如果 `network` 不是Cell对象。
- **RuntimeError** - 如果 `network` 中包含不支持解析和优化的ast节点类型。
- **RuntimeError** - `network` 不是Cell对象。
- **RuntimeError** - `network` 中包含不支持解析和优化的ast节点类型。
.. py:method:: mindspore.rewrite.SymbolTree.after(node: Node)
获取插入位置,位置为 `node` 之后。
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。我们不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数即可
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position是什么只需将其视为处理程序并将其用作SymbolTree的插入接口的参数。
参数:
- **node** (Node) - 指定插入位置在哪个节点之后。
- **node** (Node) - 指定插入位置在哪个节点之后可以是Node或者Node的名称
返回:
Position指定插入节点的位置。
异常:
- **TypeError** - 如果参数不是Node类型。
- **TypeError** - 参数不是Node类型。
.. py:method:: mindspore.rewrite.SymbolTree.before(node: Node)
与after的区别是该接口返回的位置为 `node` 之前。
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数。
参数:
- **node** (Node) - 指定插入位置在哪个节点之前。
- **node** (Node) - 指定插入位置在哪个节点之前可以是Node或者Node的名称
返回:
Position指定插入节点的位置。
异常:
- **TypeError** - 如果参数不是Node类型。
- **TypeError** - 参数不是Node类型。
.. py:method:: mindspore.rewrite.SymbolTree.create(network)
根据传入的 `network` 创建一个SymbolTree对象。
根据传入的 `network` 创建SymbolTree对象。
参数:
- **network** (Cell) - 重写的网络。现在只支持Cell类型的网络。
- **network** (Cell) - 重写的网络。
返回:
SymbolTree基于 `network` 创建的符号树。
@ -82,7 +83,7 @@ mindspore.rewrite
删除SymbolTree中的一个节点。被删除的节点必须不被其他节点依赖。
参数:
- **node** (Node) - 被删除的节点。
- **node** (Node) - 被删除的节点。可以是Node或者Node的名称。
返回:
如果 `node` 属于当前的SymbolTree则返回被删除节点。否则返回None。
@ -99,7 +100,7 @@ mindspore.rewrite
.. py:method:: mindspore.rewrite.SymbolTree.get_handler()
获取SymbolTree所对应的实现句柄
获取SymbolTree对应实现的handle
返回:
SymbolTree对象。
@ -129,7 +130,7 @@ mindspore.rewrite
如果找到则返回结果,否则返回 `None`
异常:
- **TypeError** - 如果参数不是Node类型。
- **TypeError** - 如果 `node_name` 不是Node类型。
.. py:method:: mindspore.rewrite.SymbolTree.get_saved_file_name()
@ -217,28 +218,27 @@ mindspore.rewrite
节点是表达网络中源代码的一种数据结构。
在大多数情况下Node表示一个向前计算的的运算它可以是Cell的实例、Primitive的实例或可调用的方法。
下面提到的NodeImpl是Node的实现它不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例例如 `create_call_cell`而不是直接调用Node的构造函数所以不要关心NodeImpl是什么只需要看做一个句柄即可。
在大多数情况下Node表示前向计算的的运算它可以是Cell的实例、Primitive的实例或可调用的方法。
参数:
- **node** (NodeImpl) - SymbolTree中节点的具体实现类的实例
- **node** (NodeImpl) - `NodeImpl` 的handle。NodeImpl是Node的实现不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例例如 `create_call_cell`而不直接调用Node的构造函数不需关心NodeImpl是什么只需作为handle看待
.. py:method:: mindspore.rewrite.Node.create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = "", is_sub_net: bool = False)
:staticmethod:
通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式 ``targets = self.name(*args, **kwargs)``
通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式
``targets = self.name(*args, **kwargs)``
参数:
- **cell** (Cell) - 该节点对应的前向计算的Cell对象。
- **targets** (list[ScopedValue]) - 表示输出名称。在源代码中作为节点的输出。Rewrite将在插入节点时检查并确保每个目标的唯一性。
- **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。
- **kwargs** (dict) - 键的类型必须是str值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`默认为None表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。
- **name** (str) - 表示节点的名称。用作源代码中的字段名称。默认为无。当名称为无时ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。
- **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真Rewrite将尝试将 `cell` 解析为TreeNode否则为CallCell节点。默认False。
- **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。默认值None。
- **kwargs** (dict) - 键的类型必须是str值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。默认值None。
- **name** (str) - 表示节点的名称。用作源代码中的字段名称。当名称为无时ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值None。
- **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真Rewrite将尝试将 `cell` 解析为TreeNode否则为CallCell节点。默认值:False。
返回:
一个Node实例。
Node实例。
异常:
- **TypeError** - 如果参数 `cell` 不是Cell类型。
@ -253,7 +253,7 @@ mindspore.rewrite
- 当前节点的 `node_type``CallCell``CallPrimitive``Tree` 时,返回值对应于 ast.Call 的 `args`,表示调用 `cell-op``primitive-op``forward` 方法的参数。
- 当前节点的 `node_type``Input` 时,返回值为函数参数的默认值。
- 当前节点的 `node_type``Output` 时,返回值对应网络的返回值。
- 当前节点的 `node_type``Output` 时,返回值网络的返回值。
- 当前节点的 `node_type``Python` 时,没有实际含义,可以忽略。
返回:
@ -281,14 +281,14 @@ mindspore.rewrite
.. py:method:: mindspore.rewrite.Node.get_handler()
获取节点具体实现的句柄
获取节点具体实现的handle
返回:
返回NodeImpl的实例。
.. py:method:: mindspore.rewrite.Node.get_inputs()
获取当前节点的输入节点。
获取当前节点的拓扑序的输入节点。
返回:
Node的实例列表。
@ -309,9 +309,9 @@ mindspore.rewrite
获取当前节点对应的 `operation` 实例类型。
- 如果当前节点的 `node_type``CallCell`,该节点是一个Cell对象。
- 如果当前节点的 `node_type``CallPrimitive`,该节点的是一个Primitive对象。
- 如果当前节点的 `node_type``Tree`,该节点的类型是一个网络。
- 如果当前节点的 `node_type``CallCell`该节点是Cell对象。
- 如果当前节点的 `node_type``CallPrimitive`该节点的是Primitive对象。
- 如果当前节点的 `node_type``Tree`,该节点的类型是网络。
- 如果当前节点的 `node_type``Python``Input``Output``CallMethod`该节点的类型为NoneType。
返回:
@ -332,7 +332,7 @@ mindspore.rewrite
获取当前节点的名称。当节点被插入到SymbolTree时节点的名称在SymbolTree中应该是唯一的。
返回:
str节点的名称。
节点的名称类型为str
.. py:method:: mindspore.rewrite.Node.get_next()
@ -385,14 +385,14 @@ mindspore.rewrite
- **TypeError** - 如果参数 `index` 不是int类型。
- **TypeError** - 如果参数 `arg` 不是str或者ScopedValue类型。
.. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)
.. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: Node, out_idx: Optional[int] = None)
将另一个节点设置为当前节点的输入。
参数:
- **arg_idx** (int) - 要设置的参数索引。
- **src_node** (Node) - 作为输入的节点。
- **out_idx** (intoptional) - 指定输入节点的哪个输出作为当前节点输入,默认是None则取第一个输出。
- **src_node** (Node) - 输入的节点。
- **out_idx** (intoptional) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值None。
异常:
- **RuntimeError** - 如果 `src_node` 不属于当前的SymbolTree。
@ -402,7 +402,7 @@ mindspore.rewrite
- **TypeError** - 如果参数 `src_node` 不是Node类型。
- **TypeError** - 如果参数 `out_idx` 不是int类型。
- **ValueError** - 如果参数 `out_idx` 超出了 `src_node` 的输出数量。
- **ValueError** - 如果参数 `src_node` `out_idx` 为None或者没有给 `out_idx` 赋值时,有多个输出。
- **ValueError** - 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,参数 `src_node` 有多个输出。
.. py:method:: mindspore.rewrite.Node.set_attribute(key: str, value)
@ -432,21 +432,21 @@ mindspore.rewrite
ScopedValue表示具有完整范围的值。
ScopedValue用于表示一个左值,如赋值语句的目标,或可调用对象,如调用语句的 `func`,或右值,如赋值语句的 `args``kwargs`
ScopedValue用于表示左值如赋值语句的目标或可调用对象如调用语句的 `func`,或右值,如赋值语句的 `args``kwargs`
参数:
- **arg_type** (ValueType) - 表示当前值的类型。
- **scope** (str) - 一个字符串表示当前值的范围。以"self.var1"为例这个var1的作用域是"self"。
- **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`
- **arg_type** (ValueType) - 当前值的类型。
- **scope** (str) - 字符串表示当前值的范围。以"self.var1"为例这个var1的作用域是"self"。默认值:""。
- **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`默认值None。
.. py:method:: mindspore.rewrite.ScopedValue.create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)
:staticmethod:
创建一个ScopedValue的列表。
创建ScopedValue的列表。
参数:
- **names** (list[str] or tuple[str]) str 的列表或元组表示引用变量的名称
- **scopes** (list[str] or tuple[str]) str 的列表或元组表示引用变量的范围默认值None表示没有指定作用范围
- **names** (list[str] or tuple[str]) 引用变量的名称,类型为str的列表或元组。
- **scopes** (list[str] or tuple[str]) 引用变量的范围类型为str的列表或元组。表示没有指定作用范围。默认值None
返回:
ScopedValue的实例列表。
@ -462,7 +462,7 @@ mindspore.rewrite
参数:
- **name** (str) 表示变量的字符串。
- **scope** (str) 表示变量范围的字符串,默认值为空字符串,表示没有指定作用范围。
- **scope** (str) 表示变量范围的字符串,表示没有指定作用范围。默认值:空字符串。
返回:
ScopedValue的实例。
@ -490,11 +490,11 @@ mindspore.rewrite
.. py:class:: mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)
PatternEngine实现了如何通过PattenNode修改SymbolTree。
PatternEngine通过PattenNode修改SymbolTree。
参数:
- **pattern** (Union[PatternNode, List]) - PatternNode的实例或用于构造 `Pattent` 的Cell类型列表。
- **replacement** (callable) - 生成新节点的接口实现如果为None则不进行任何匹配操作
- **replacement** (callable) - 生成新节点的接口实现。
.. py:method:: mindspore.rewrite.PatternEngine.apply(stree: SymbolTree)
@ -525,12 +525,12 @@ mindspore.rewrite
参数:
- **pattern_node_name** (str) - 节点名称。
- **match_type** (Type) - 当前节点的匹配类型。
- **inputs** (list[PatternNode]) - 当前节点的输入节点。
- **match_type** (Type) - 当前节点的匹配类型。默认值None。
- **inputs** (list[PatternNode]) - 当前节点的输入节点。默认值None。
.. py:method:: mindspore.rewrite.PatternNode.add_input(node)
为当前节点添加一个输入。
为当前节点添加输入。
参数:
- **node** (PatternNode) - 新增的输入节点。
@ -541,10 +541,10 @@ mindspore.rewrite
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_list(type_list: [])
:staticmethod:
使用一个类型的列表来创建一个Pattern。
使用类型的列表来创建Pattern。
参数:
- **type_list** (list[type]) - 类型列表当前支持Cell和Primitive
- **type_list** (list[type]) - 类型列表。
返回:
根据列表生成的模式的根节点。
@ -555,7 +555,7 @@ mindspore.rewrite
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_node(node: Node)
:staticmethod:
根据一个节点及其输入创建一个Pattern。
根据节点及其输入创建Pattern。
参数:
- **node** (Node) - 要修改的节点。
@ -635,7 +635,7 @@ mindspore.rewrite
参数:
- **pattern** (PatternNode) - 当前模式的根节点。
- **is_chain_pattern** (bool) - 标记模式是链模式或树模式。
- **is_chain_pattern** (bool) - 标记,标记模式是链模式或树模式。
- **matched** (OrderedDict) - 匹配结果,从名称映射到节点的字典。
返回:
@ -645,7 +645,7 @@ mindspore.rewrite
TreeNodeHelper用于在从Tree类型节点获取 `symbol_tree` 时打破循环引用。
TreeNodeHelper提供了一个静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`
TreeNodeHelper提供了静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`
.. py:method:: mindspore.rewrite.TreeNodeHelper.get_sub_tree(node: Node)
:staticmethod:
@ -653,11 +653,11 @@ mindspore.rewrite
获取Tree类型节点的 `symbol_tree`
参数:
- **node** (Node) - 一个可以持有子符号树的节点。
- **node** (Node) - 可以持有子符号树的节点。
返回:
Tree节点中的SymbolTree对象。注意节点的 `symbol_tree` 可能是None在这种情况下方法将返回None。
异常:
- **RuntimeError** - 如果参数 `node` `node_type` 不是Tree类型。
- **RuntimeError** - 如果参数 `node` 不是 NodeType.Tree类型。
- **TypeError** - 如果参数 `node` 不是Node类型实例。

View File

@ -32,13 +32,11 @@ class Node:
For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance
of `Primitive` or a callable method.
`NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend
invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than
invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a
handler.
Args:
node (NodeImpl): A handler of `NodeImpl`.
node (NodeImpl): A handler of `NodeImpl`. `NodeImpl` mentioned below is implementation of `Node` which is not
an interface of Rewrite. Rewrite recommend invoking specific create method of `Node` to instantiate an instance
of Node such as `create_call_cell` rather than invoking constructor of `Node` directly, so don't care about
what is `NodeImpl` and use its instance just as a handler.
"""
def __init__(self, node: NodeImpl):
@ -142,6 +140,14 @@ class Node:
Returns:
A list of nodes represents users.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
"""
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
if belong_symbol_tree is None:
@ -164,6 +170,14 @@ class Node:
Raises:
TypeError: If `index` is not a `int` number.
TypeError: If the type of `arg` is not in [`ScopedValue`, `str`].
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node.set_arg(0, "x")
"""
Validator.check_value_type("index", index, [int], "Node")
Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
@ -175,7 +189,7 @@ class Node:
def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
"""
Set argument of current node by another `Node`.
Set argument of current node by another Node.
Args:
arg_idx (int): Indicate which input being modified.
@ -192,6 +206,15 @@ class Node:
TypeError: If `out_idx` is not a `int` number.
ValueError: If `out_idx` is out of range.
ValueError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("conv1")
>>> dst_node = stree.get_node("conv2")
>>> dst_node.set_arg_by_node(0, src_node)
"""
Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
Validator.check_value_type("src_node", src_node, [Node], "Node")
@ -216,6 +239,14 @@ class Node:
Returns:
A list of instances of ScopedValue as targets of node.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> targets = node.get_targets()
"""
return self._node.get_targets()
@ -227,6 +258,14 @@ class Node:
Returns:
A string as name of node.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
"""
return self._node.get_name()
@ -236,6 +275,14 @@ class Node:
Returns:
A NodeType as node_type of node.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
"""
return self._node.get_node_type()
@ -272,14 +319,22 @@ class Node:
"""
Get the arguments of current node.
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
- When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
- When node_type of current node is `Input`, arguments represents default-value of argument of function.
- When node_type of current node is `Output`, arguments represents return values.
- When node_type of current node is `Python`, arguments are don't-care.
- When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
- When `node_type` of current node is `Output`, arguments represents the return values of network.
- When `node_type` of current node is `Python`, arguments are don't-care.
Returns:
A list of instances of `ScopedValue`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> args = node.get_args()
"""
return self._node.get_args()
@ -293,6 +348,14 @@ class Node:
Returns:
A dict of str to instance of `ScopedValue`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> kwargs = node.get_kwargs()
"""
return self._node.get_kwargs()
@ -306,6 +369,14 @@ class Node:
Raises:
TypeError: If `key` is not a `str`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node.set_attribute("channel", 3)
"""
Validator.check_value_type("key", key, [str], "Node attribute")
self._node.set_attribute(key, value)
@ -327,7 +398,7 @@ class Node:
key (str): Key of attribute.
Returns:
A object as attribute.
A object as attribute, can be any type.
Raises:
TypeError: If `key` is not a `str`.

View File

@ -33,8 +33,8 @@ class PatternNode:
Args:
pattern_node_name (str): Name of current node.
match_type (Type): A type represents what type would be matched of current node.
inputs (list[PatternNode]): Input nodes of current node.
match_type (Type): A type represents what type would be matched of current node. Default: None.
inputs (list[PatternNode]): Input nodes of current node. Default: None.
"""
def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None):
@ -166,13 +166,13 @@ class PatternNode:
def name(self) -> str:
"""
Getter of name.
Getter of PatternNode name.
"""
return self._name
def type(self):
"""
Getter of type.
Getter of PatternNode type.
"""
return self._type
@ -192,8 +192,17 @@ class VarNode(PatternNode):
class Replacement(abc.ABC):
"""
Interface of replacement function.
"""
Examples:
>>> from mindspore.rewrite import Replacement, Node
>>> from mindspore.nn import nn
>>> class BnReplacement(Replacement):
... def build(self, pattern, is_chain_pattern: bool, matched):
... bn_node: Node = matched.get(pattern.name())
... conv = nn.Conv2d(16, 16, 3)
... conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
... return [conv_node]
"""
@abc.abstractmethod
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
"""

View File

@ -73,6 +73,10 @@ class ScopedValue:
Returns:
An instance of `ScopedValue`.
Examples:
>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
"""
if isinstance(value, int):
return cls(ValueType.IntValue, "", value)
@ -108,6 +112,10 @@ class ScopedValue:
Raises:
TypeError: If `name` is not `str`.
TypeError: If `scope` is not `str`.
Examples:
>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
"""
Validator.check_value_type("name", name, [str], "ScopedValue")
Validator.check_value_type("scope", scope, [str], "ScopedValue")
@ -129,6 +137,10 @@ class ScopedValue:
RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None.
TypeError: If `names` is not `list` or `tuple` and name in `names` is not `str`.
TypeError: If `scopes` is not `list` or `tuple` and scope in `scopes` is not `str`.
Examples:
>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(["z", "z_1"]), name="subnet")
"""
Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue")
if scopes is not None:

View File

@ -99,6 +99,14 @@ class SymbolTree:
TypeError: If the type of `targets` is not str.
TypeError: If arg in `args` is not ParamType.
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
"""
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
@ -119,6 +127,13 @@ class SymbolTree:
Returns:
An instance of `SymbolTree`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> handler = stree.get_handler()
"""
return self._symbol_tree
@ -128,6 +143,14 @@ class SymbolTree:
Returns:
A generator for node of current `SymbolTree`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
... node.set_attribute("channel", 3)
"""
for node in self._symbol_tree.nodes():
yield Node(node)
@ -144,6 +167,13 @@ class SymbolTree:
Raises:
TypeError: If `node_name` is not `str`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
"""
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
node_impl = self._symbol_tree.get_node(node_name)
@ -157,6 +187,13 @@ class SymbolTree:
Returns:
[Node], the node list of the current `Symboltree`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> inputs = stree.get_inputs()
"""
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
@ -176,6 +213,15 @@ class SymbolTree:
Raises:
TypeError: if `node` is not a `Node`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
... if node.get_name() == "conv1":
... position = stree.before(node)
"""
Validator.check_value_type("node", node, [Node], "SymbolTree")
return self._symbol_tree.before(node.get_handler())
@ -196,6 +242,15 @@ class SymbolTree:
Raises:
TypeError: If `node` is not a `Node`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
... if node.get_name() == "conv1":
... position = stree.after(node)
"""
Validator.check_value_type("node", node, [Node], "SymbolTree")
return self._symbol_tree.after(node.get_handler())
@ -218,6 +273,16 @@ class SymbolTree:
RuntimeError: If `position` is not belong to current `SymbolTree`.
TypeError: If `position` is not a `Position`.
TypeError: If `node` is not a `Node`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
>>> stree.insert(position, new_node)
"""
Validator.check_value_type("position", position, [Position], "SymbolTree")
Validator.check_value_type("node", node, [Node], "SymbolTree")
@ -235,6 +300,18 @@ class SymbolTree:
Raises:
TypeError: If `node` is not a `Node`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> input_node = node.get_inputs()[0]
>>> output_nodes = node.get_users()
>>> for n in output_nodes:
... n.set_arg(0, "x")
>>> stree.erase_node(node)
"""
Validator.check_value_type("node", node, [Node], "SymbolTree")
return Node(self._symbol_tree.erase_node(node.get_handler()))
@ -264,9 +341,18 @@ class SymbolTree:
An instance of Node represents root of node_tree been replaced in.
Raises:
RuntimeError: Old node is isolated.
RuntimeError: Old node is not isolated.
TypeError: If `old_node` is not a `Node`.
TypeError: If `new_nodes` is not a `list` or node in `new_nodes` is not a `Node`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
>>> stree.replace(node, [new_node])
"""
Validator.check_value_type("old_node", old_node, [Node], "SymbolTree")
Validator.check_element_type_of_iterable("new_nodes", new_nodes, [Node], "SymbolTree")
@ -288,6 +374,13 @@ class SymbolTree:
RuntimeError: If `index` is out of range.
TypeError: If `index` is not a `int` number.
TypeError: If `return_value` is not a `str`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.set_output(0, "x_10")
"""
Validator.check_value_type("index", index, [int], "SymbolTree")
Validator.check_value_type("return_value", return_value, [str], "SymbolTree")
@ -295,13 +388,20 @@ class SymbolTree:
def dump(self):
"""
Dump graph to console.
Print the ir map information corresponding to the network in 'SymbolTree' to the screen.
"""
self._symbol_tree.dump()
def print_node_tabulate(self):
"""
Print node information of graph.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.print_node_tabulate()
"""
self._symbol_tree.print_node_tabulate()
@ -311,6 +411,13 @@ class SymbolTree:
Returns:
A str represents source code of modified network.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.get_code()
"""
return self._symbol_tree.get_code()
@ -321,6 +428,13 @@ class SymbolTree:
Returns:
A network object.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.get_network()
"""
return self._symbol_tree.get_network()
@ -330,16 +444,40 @@ class SymbolTree:
Args:
file_name (str): filename to be set.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.set_saved_file_name("new_net")
"""
Validator.check_value_type("file_name", file_name, [str], "Saving network")
self._symbol_tree.set_saved_file_name(file_name)
def get_saved_file_name(self):
"""Gets the filename used to save the network."""
"""
Gets the filename used to save the network.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)x
>>> stree.set_saved_file_name("new_net")
>>> stree.get_saved_file_name()
"""
return self._symbol_tree.get_saved_file_name()
def save_network_to_file(self):
"""
Save the modified network to a file. Default file name is `network_define.py`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.save_network_to_file()
"""
self._symbol_tree.save_network_to_file()

View File

@ -36,7 +36,7 @@ class TreeNodeHelper:
Getting symbol_tree from a `Tree` type `Node`.
Args:
node (Node): A `Node` who may hold a sub-symbol_tree.
node (Node): A `Node` which may hold a sub-symbol_tree.
Returns:
An instance of SymbolTree represents sub-symbol_tree. Note that `node`'s symbol_tree maybe None, in this