fix rewrite doc and add code examples
This commit is contained in:
parent
0b25747876
commit
f0c2fdc0cc
|
@ -4,48 +4,49 @@ mindspore.rewrite
|
|||
|
||||
.. py:class:: mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)
|
||||
|
||||
SymbolTree通常对应于网络的forward方法。
|
||||
SymbolTree通常对应于网络的前向计算过程。
|
||||
|
||||
参数:
|
||||
- **network** (Cell) - 要重写的网络。现在只支持Cell类型的网络。
|
||||
- **handler** (SymbolTreeImpl) - SymbolTree内部实现实例。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 如果 `network` 不是Cell对象。
|
||||
- **RuntimeError** - 如果 `network` 中包含不支持解析和优化的ast节点类型。
|
||||
- **RuntimeError** - `network` 不是Cell对象。
|
||||
- **RuntimeError** - `network` 中包含不支持解析和优化的ast节点类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.after(node: Node)
|
||||
|
||||
获取插入位置,位置为 `node` 之后。
|
||||
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。我们不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数即可。
|
||||
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position是什么,只需将其视为处理程序并将其用作SymbolTree的插入接口的参数。
|
||||
|
||||
参数:
|
||||
- **node** (Node) - 指定插入位置在哪个节点之后。
|
||||
- **node** (Node) - 指定插入位置在哪个节点之后,可以是Node或者Node的名称。
|
||||
|
||||
返回:
|
||||
Position,指定插入节点的位置。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数不是Node类型。
|
||||
- **TypeError** - 参数不是Node类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.before(node: Node)
|
||||
|
||||
与after的区别是,该接口返回的位置为 `node` 之前。
|
||||
返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数。
|
||||
|
||||
参数:
|
||||
- **node** (Node) - 指定插入位置在哪个节点之前。
|
||||
- **node** (Node) - 指定插入位置在哪个节点之前,可以是Node或者Node的名称。
|
||||
|
||||
返回:
|
||||
Position,指定插入节点的位置。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数不是Node类型。
|
||||
- **TypeError** - 参数不是Node类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.create(network)
|
||||
|
||||
根据传入的 `network` 创建一个SymbolTree对象。
|
||||
根据传入的 `network` 创建SymbolTree对象。
|
||||
|
||||
参数:
|
||||
- **network** (Cell) - 要重写的网络。现在只支持Cell类型的网络。
|
||||
- **network** (Cell) - 重写的网络。
|
||||
|
||||
返回:
|
||||
SymbolTree,基于 `network` 创建的符号树。
|
||||
|
@ -82,7 +83,7 @@ mindspore.rewrite
|
|||
删除SymbolTree中的一个节点。被删除的节点必须不被其他节点依赖。
|
||||
|
||||
参数:
|
||||
- **node** (Node) - 被删除的节点。
|
||||
- **node** (Node) - 被删除的节点。可以是Node或者Node的名称。
|
||||
|
||||
返回:
|
||||
如果 `node` 属于当前的SymbolTree则返回被删除节点。否则返回None。
|
||||
|
@ -99,7 +100,7 @@ mindspore.rewrite
|
|||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.get_handler()
|
||||
|
||||
获取SymbolTree所对应的实现句柄。
|
||||
获取SymbolTree对应实现的handle。
|
||||
|
||||
返回:
|
||||
SymbolTree对象。
|
||||
|
@ -129,7 +130,7 @@ mindspore.rewrite
|
|||
如果找到则返回结果,否则返回 `None`。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数不是Node类型。
|
||||
- **TypeError** - 如果 `node_name` 不是Node类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.get_saved_file_name()
|
||||
|
||||
|
@ -217,28 +218,27 @@ mindspore.rewrite
|
|||
|
||||
节点是表达网络中源代码的一种数据结构。
|
||||
|
||||
在大多数情况下,Node表示一个向前计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。
|
||||
|
||||
下面提到的NodeImpl是Node的实现,它不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例,例如 `create_call_cell`,而不是直接调用Node的构造函数,所以不要关心NodeImpl是什么,只需要看做一个句柄即可。
|
||||
在大多数情况下,Node表示前向计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。
|
||||
|
||||
参数:
|
||||
- **node** (NodeImpl) - SymbolTree中节点的具体实现类的实例。
|
||||
- **node** (NodeImpl) - `NodeImpl` 的handle。NodeImpl是Node的实现,不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例,例如 `create_call_cell`,而不直接调用Node的构造函数,不需关心NodeImpl是什么,只需作为handle看待。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = "", is_sub_net: bool = False)
|
||||
:staticmethod:
|
||||
|
||||
通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式: ``targets = self.name(*args, **kwargs)``。
|
||||
通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式:
|
||||
``targets = self.name(*args, **kwargs)``。
|
||||
|
||||
参数:
|
||||
- **cell** (Cell) - 该节点对应的前向计算的Cell对象。
|
||||
- **targets** (list[ScopedValue]) - 表示输出名称。在源代码中作为节点的输出。Rewrite将在插入节点时检查并确保每个目标的唯一性。
|
||||
- **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。
|
||||
- **kwargs** (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。默认为None,表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。
|
||||
- **name** (str) - 表示节点的名称。用作源代码中的字段名称。默认为无。当名称为无时,ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。
|
||||
- **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真,Rewrite将尝试将 `cell` 解析为TreeNode,否则为CallCell节点。默认为False。
|
||||
- **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。默认值:None。
|
||||
- **kwargs** (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。默认值:None。
|
||||
- **name** (str) - 表示节点的名称。用作源代码中的字段名称。当名称为无时,ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值:None。
|
||||
- **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真,Rewrite将尝试将 `cell` 解析为TreeNode,否则为CallCell节点。默认值:False。
|
||||
|
||||
返回:
|
||||
一个Node实例。
|
||||
Node实例。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数 `cell` 不是Cell类型。
|
||||
|
@ -253,7 +253,7 @@ mindspore.rewrite
|
|||
|
||||
- 当前节点的 `node_type` 为 `CallCell`、 `CallPrimitive` 或 `Tree` 时,返回值对应于 ast.Call 的 `args`,表示调用 `cell-op` 或 `primitive-op` 的 `forward` 方法的参数。
|
||||
- 当前节点的 `node_type` 为 `Input` 时,返回值为函数参数的默认值。
|
||||
- 当前节点的 `node_type` 为 `Output` 时,返回值对应网络的返回值。
|
||||
- 当前节点的 `node_type` 为 `Output` 时,返回值为网络的返回值。
|
||||
- 当前节点的 `node_type` 为 `Python` 时,没有实际含义,可以忽略。
|
||||
|
||||
返回:
|
||||
|
@ -281,14 +281,14 @@ mindspore.rewrite
|
|||
|
||||
.. py:method:: mindspore.rewrite.Node.get_handler()
|
||||
|
||||
获取节点具体实现的句柄。
|
||||
获取节点具体实现的handle。
|
||||
|
||||
返回:
|
||||
返回NodeImpl的实例。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_inputs()
|
||||
|
||||
获取当前节点的输入节点。
|
||||
获取当前节点的拓扑序的输入节点。
|
||||
|
||||
返回:
|
||||
Node的实例列表。
|
||||
|
@ -309,9 +309,9 @@ mindspore.rewrite
|
|||
|
||||
获取当前节点对应的 `operation` 实例类型。
|
||||
|
||||
- 如果当前节点的 `node_type` 是 `CallCell`,该节点是一个Cell对象。
|
||||
- 如果当前节点的 `node_type` 是 `CallPrimitive`,该节点的是一个Primitive对象。
|
||||
- 如果当前节点的 `node_type` 是 `Tree`,该节点的类型是一个网络。
|
||||
- 如果当前节点的 `node_type` 是 `CallCell`,该节点是Cell对象。
|
||||
- 如果当前节点的 `node_type` 是 `CallPrimitive`,该节点的是Primitive对象。
|
||||
- 如果当前节点的 `node_type` 是 `Tree`,该节点的类型是网络。
|
||||
- 如果当前节点的 `node_type` 是 `Python`、 `Input`、 `Output`、 `CallMethod`,该节点的类型为NoneType。
|
||||
|
||||
返回:
|
||||
|
@ -332,7 +332,7 @@ mindspore.rewrite
|
|||
获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。
|
||||
|
||||
返回:
|
||||
str,节点的名称。
|
||||
节点的名称,类型为str。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_next()
|
||||
|
||||
|
@ -385,14 +385,14 @@ mindspore.rewrite
|
|||
- **TypeError** - 如果参数 `index` 不是int类型。
|
||||
- **TypeError** - 如果参数 `arg` 不是str或者ScopedValue类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)
|
||||
.. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: Node, out_idx: Optional[int] = None)
|
||||
|
||||
将另一个节点设置为当前节点的输入。
|
||||
|
||||
参数:
|
||||
- **arg_idx** (int) - 要设置的参数索引。
|
||||
- **src_node** (Node) - 作为输入的节点。
|
||||
- **out_idx** (int,optional) - 指定输入节点的哪个输出作为当前节点输入,默认是None,则取第一个输出。
|
||||
- **src_node** (Node) - 输入的节点。
|
||||
- **out_idx** (int,optional) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值:None。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 如果 `src_node` 不属于当前的SymbolTree。
|
||||
|
@ -402,7 +402,7 @@ mindspore.rewrite
|
|||
- **TypeError** - 如果参数 `src_node` 不是Node类型。
|
||||
- **TypeError** - 如果参数 `out_idx` 不是int类型。
|
||||
- **ValueError** - 如果参数 `out_idx` 超出了 `src_node` 的输出数量。
|
||||
- **ValueError** - 如果参数 `src_node` 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,有多个输出。
|
||||
- **ValueError** - 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,参数 `src_node` 有多个输出。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.set_attribute(key: str, value)
|
||||
|
||||
|
@ -432,21 +432,21 @@ mindspore.rewrite
|
|||
|
||||
ScopedValue表示具有完整范围的值。
|
||||
|
||||
ScopedValue用于表示:一个左值,如赋值语句的目标,或可调用对象,如调用语句的 `func`,或右值,如赋值语句的 `args` 和 `kwargs`。
|
||||
ScopedValue用于表示:左值,如赋值语句的目标,或可调用对象,如调用语句的 `func`,或右值,如赋值语句的 `args` 和 `kwargs`。
|
||||
|
||||
参数:
|
||||
- **arg_type** (ValueType) - 表示当前值的类型。
|
||||
- **scope** (str) - 一个字符串表示当前值的范围。以"self.var1"为例,这个var1的作用域是"self"。
|
||||
- **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`。
|
||||
- **arg_type** (ValueType) - 当前值的类型。
|
||||
- **scope** (str) - 字符串表示当前值的范围。以"self.var1"为例,这个var1的作用域是"self"。默认值:""。
|
||||
- **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`。默认值:None。
|
||||
|
||||
.. py:method:: mindspore.rewrite.ScopedValue.create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)
|
||||
:staticmethod:
|
||||
|
||||
创建一个ScopedValue的列表。
|
||||
创建ScopedValue的列表。
|
||||
|
||||
参数:
|
||||
- **names** (list[str] or tuple[str]) – str 的列表或元组表示引用变量的名称。
|
||||
- **scopes** (list[str] or tuple[str]) – str 的列表或元组表示引用变量的范围,默认值None表示没有指定作用范围。
|
||||
- **names** (list[str] or tuple[str]) – 引用变量的名称,类型为str的列表或元组。
|
||||
- **scopes** (list[str] or tuple[str]) – 引用变量的范围,类型为str的列表或元组。表示没有指定作用范围。默认值:None。
|
||||
|
||||
返回:
|
||||
ScopedValue的实例列表。
|
||||
|
@ -462,7 +462,7 @@ mindspore.rewrite
|
|||
|
||||
参数:
|
||||
- **name** (str) – 表示变量的字符串。
|
||||
- **scope** (str) – 表示变量范围的字符串,默认值为空字符串,表示没有指定作用范围。
|
||||
- **scope** (str) – 表示变量范围的字符串,表示没有指定作用范围。默认值:空字符串。
|
||||
|
||||
返回:
|
||||
ScopedValue的实例。
|
||||
|
@ -490,11 +490,11 @@ mindspore.rewrite
|
|||
|
||||
.. py:class:: mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)
|
||||
|
||||
PatternEngine实现了如何通过PattenNode修改SymbolTree。
|
||||
PatternEngine通过PattenNode修改SymbolTree。
|
||||
|
||||
参数:
|
||||
- **pattern** (Union[PatternNode, List]) - PatternNode的实例或用于构造 `Pattent` 的Cell类型列表。
|
||||
- **replacement** (callable) - 生成新节点的接口实现,如果为None则不进行任何匹配操作。
|
||||
- **replacement** (callable) - 生成新节点的接口实现。
|
||||
|
||||
.. py:method:: mindspore.rewrite.PatternEngine.apply(stree: SymbolTree)
|
||||
|
||||
|
@ -525,12 +525,12 @@ mindspore.rewrite
|
|||
|
||||
参数:
|
||||
- **pattern_node_name** (str) - 节点名称。
|
||||
- **match_type** (Type) - 当前节点的匹配类型。
|
||||
- **inputs** (list[PatternNode]) - 当前节点的输入节点。
|
||||
- **match_type** (Type) - 当前节点的匹配类型。默认值:None。
|
||||
- **inputs** (list[PatternNode]) - 当前节点的输入节点。默认值:None。
|
||||
|
||||
.. py:method:: mindspore.rewrite.PatternNode.add_input(node)
|
||||
|
||||
为当前节点添加一个输入。
|
||||
为当前节点添加输入。
|
||||
|
||||
参数:
|
||||
- **node** (PatternNode) - 新增的输入节点。
|
||||
|
@ -541,10 +541,10 @@ mindspore.rewrite
|
|||
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_list(type_list: [])
|
||||
:staticmethod:
|
||||
|
||||
使用一个类型的列表来创建一个Pattern。
|
||||
使用类型的列表来创建Pattern。
|
||||
|
||||
参数:
|
||||
- **type_list** (list[type]) - 类型列表,当前支持Cell和Primitive。
|
||||
- **type_list** (list[type]) - 类型列表。
|
||||
|
||||
返回:
|
||||
根据列表生成的模式的根节点。
|
||||
|
@ -555,7 +555,7 @@ mindspore.rewrite
|
|||
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_node(node: Node)
|
||||
:staticmethod:
|
||||
|
||||
根据一个节点及其输入创建一个Pattern。
|
||||
根据节点及其输入创建Pattern。
|
||||
|
||||
参数:
|
||||
- **node** (Node) - 要修改的节点。
|
||||
|
@ -635,7 +635,7 @@ mindspore.rewrite
|
|||
|
||||
参数:
|
||||
- **pattern** (PatternNode) - 当前模式的根节点。
|
||||
- **is_chain_pattern** (bool) - 标记模式是链模式或树模式。
|
||||
- **is_chain_pattern** (bool) - 标记,标记模式是链模式或树模式。
|
||||
- **matched** (OrderedDict) - 匹配结果,从名称映射到节点的字典。
|
||||
|
||||
返回:
|
||||
|
@ -645,7 +645,7 @@ mindspore.rewrite
|
|||
|
||||
TreeNodeHelper用于在从Tree类型节点获取 `symbol_tree` 时打破循环引用。
|
||||
|
||||
TreeNodeHelper提供了一个静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`。
|
||||
TreeNodeHelper提供了静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`。
|
||||
|
||||
.. py:method:: mindspore.rewrite.TreeNodeHelper.get_sub_tree(node: Node)
|
||||
:staticmethod:
|
||||
|
@ -653,11 +653,11 @@ mindspore.rewrite
|
|||
获取Tree类型节点的 `symbol_tree`。
|
||||
|
||||
参数:
|
||||
- **node** (Node) - 一个可以持有子符号树的节点。
|
||||
- **node** (Node) - 可以持有子符号树的节点。
|
||||
|
||||
返回:
|
||||
Tree节点中的SymbolTree对象。注意节点的 `symbol_tree` 可能是None,在这种情况下,方法将返回None。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 如果参数 `node` 的 `node_type` 不是Tree类型。
|
||||
- **RuntimeError** - 如果参数 `node` 不是 NodeType.Tree类型。
|
||||
- **TypeError** - 如果参数 `node` 不是Node类型实例。
|
||||
|
|
|
@ -32,13 +32,11 @@ class Node:
|
|||
For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance
|
||||
of `Primitive` or a callable method.
|
||||
|
||||
`NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend
|
||||
invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than
|
||||
invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a
|
||||
handler.
|
||||
|
||||
Args:
|
||||
node (NodeImpl): A handler of `NodeImpl`.
|
||||
node (NodeImpl): A handler of `NodeImpl`. `NodeImpl` mentioned below is implementation of `Node` which is not
|
||||
an interface of Rewrite. Rewrite recommend invoking specific create method of `Node` to instantiate an instance
|
||||
of Node such as `create_call_cell` rather than invoking constructor of `Node` directly, so don't care about
|
||||
what is `NodeImpl` and use its instance just as a handler.
|
||||
"""
|
||||
|
||||
def __init__(self, node: NodeImpl):
|
||||
|
@ -142,6 +140,14 @@ class Node:
|
|||
|
||||
Returns:
|
||||
A list of nodes represents users.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> users = node.get_users()
|
||||
"""
|
||||
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
|
||||
if belong_symbol_tree is None:
|
||||
|
@ -164,6 +170,14 @@ class Node:
|
|||
Raises:
|
||||
TypeError: If `index` is not a `int` number.
|
||||
TypeError: If the type of `arg` is not in [`ScopedValue`, `str`].
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> node.set_arg(0, "x")
|
||||
"""
|
||||
Validator.check_value_type("index", index, [int], "Node")
|
||||
Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
|
||||
|
@ -175,7 +189,7 @@ class Node:
|
|||
|
||||
def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
|
||||
"""
|
||||
Set argument of current node by another `Node`.
|
||||
Set argument of current node by another Node.
|
||||
|
||||
Args:
|
||||
arg_idx (int): Indicate which input being modified.
|
||||
|
@ -192,6 +206,15 @@ class Node:
|
|||
TypeError: If `out_idx` is not a `int` number.
|
||||
ValueError: If `out_idx` is out of range.
|
||||
ValueError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> src_node = stree.get_node("conv1")
|
||||
>>> dst_node = stree.get_node("conv2")
|
||||
>>> dst_node.set_arg_by_node(0, src_node)
|
||||
"""
|
||||
Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
|
||||
Validator.check_value_type("src_node", src_node, [Node], "Node")
|
||||
|
@ -216,6 +239,14 @@ class Node:
|
|||
|
||||
Returns:
|
||||
A list of instances of ScopedValue as targets of node.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> targets = node.get_targets()
|
||||
"""
|
||||
return self._node.get_targets()
|
||||
|
||||
|
@ -227,6 +258,14 @@ class Node:
|
|||
|
||||
Returns:
|
||||
A string as name of node.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> name = node.get_name()
|
||||
"""
|
||||
return self._node.get_name()
|
||||
|
||||
|
@ -236,6 +275,14 @@ class Node:
|
|||
|
||||
Returns:
|
||||
A NodeType as node_type of node.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> node_type = node.get_node_type()
|
||||
"""
|
||||
return self._node.get_node_type()
|
||||
|
||||
|
@ -272,14 +319,22 @@ class Node:
|
|||
"""
|
||||
Get the arguments of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
|
||||
- When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
|
||||
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
||||
- When node_type of current node is `Input`, arguments represents default-value of argument of function.
|
||||
- When node_type of current node is `Output`, arguments represents return values.
|
||||
- When node_type of current node is `Python`, arguments are don't-care.
|
||||
- When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
|
||||
- When `node_type` of current node is `Output`, arguments represents the return values of network.
|
||||
- When `node_type` of current node is `Python`, arguments are don't-care.
|
||||
|
||||
Returns:
|
||||
A list of instances of `ScopedValue`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> args = node.get_args()
|
||||
"""
|
||||
return self._node.get_args()
|
||||
|
||||
|
@ -293,6 +348,14 @@ class Node:
|
|||
|
||||
Returns:
|
||||
A dict of str to instance of `ScopedValue`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> kwargs = node.get_kwargs()
|
||||
"""
|
||||
return self._node.get_kwargs()
|
||||
|
||||
|
@ -306,6 +369,14 @@ class Node:
|
|||
|
||||
Raises:
|
||||
TypeError: If `key` is not a `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> node.set_attribute("channel", 3)
|
||||
"""
|
||||
Validator.check_value_type("key", key, [str], "Node attribute")
|
||||
self._node.set_attribute(key, value)
|
||||
|
@ -327,7 +398,7 @@ class Node:
|
|||
key (str): Key of attribute.
|
||||
|
||||
Returns:
|
||||
A object as attribute.
|
||||
A object as attribute, can be any type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `key` is not a `str`.
|
||||
|
|
|
@ -33,8 +33,8 @@ class PatternNode:
|
|||
|
||||
Args:
|
||||
pattern_node_name (str): Name of current node.
|
||||
match_type (Type): A type represents what type would be matched of current node.
|
||||
inputs (list[PatternNode]): Input nodes of current node.
|
||||
match_type (Type): A type represents what type would be matched of current node. Default: None.
|
||||
inputs (list[PatternNode]): Input nodes of current node. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None):
|
||||
|
@ -166,13 +166,13 @@ class PatternNode:
|
|||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Getter of name.
|
||||
Getter of PatternNode name.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
def type(self):
|
||||
"""
|
||||
Getter of type.
|
||||
Getter of PatternNode type.
|
||||
"""
|
||||
return self._type
|
||||
|
||||
|
@ -192,8 +192,17 @@ class VarNode(PatternNode):
|
|||
class Replacement(abc.ABC):
|
||||
"""
|
||||
Interface of replacement function.
|
||||
"""
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import Replacement, Node
|
||||
>>> from mindspore.nn import nn
|
||||
>>> class BnReplacement(Replacement):
|
||||
... def build(self, pattern, is_chain_pattern: bool, matched):
|
||||
... bn_node: Node = matched.get(pattern.name())
|
||||
... conv = nn.Conv2d(16, 16, 3)
|
||||
... conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
|
||||
... return [conv_node]
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
"""
|
||||
|
|
|
@ -73,6 +73,10 @@ class ScopedValue:
|
|||
|
||||
Returns:
|
||||
An instance of `ScopedValue`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import ScopedValue
|
||||
>>> variable = ScopedValue.create_variable_value(2)
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
return cls(ValueType.IntValue, "", value)
|
||||
|
@ -108,6 +112,10 @@ class ScopedValue:
|
|||
Raises:
|
||||
TypeError: If `name` is not `str`.
|
||||
TypeError: If `scope` is not `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import ScopedValue
|
||||
>>> variable = ScopedValue.create_naming_value("conv", "self")
|
||||
"""
|
||||
Validator.check_value_type("name", name, [str], "ScopedValue")
|
||||
Validator.check_value_type("scope", scope, [str], "ScopedValue")
|
||||
|
@ -129,6 +137,10 @@ class ScopedValue:
|
|||
RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None.
|
||||
TypeError: If `names` is not `list` or `tuple` and name in `names` is not `str`.
|
||||
TypeError: If `scopes` is not `list` or `tuple` and scope in `scopes` is not `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import ScopedValue
|
||||
>>> variables = ScopedValue.create_name_values(["z", "z_1"]), name="subnet")
|
||||
"""
|
||||
Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue")
|
||||
if scopes is not None:
|
||||
|
|
|
@ -99,6 +99,14 @@ class SymbolTree:
|
|||
TypeError: If the type of `targets` is not str.
|
||||
TypeError: If arg in `args` is not ParamType.
|
||||
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
|
||||
"""
|
||||
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
|
||||
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
|
||||
|
@ -119,6 +127,13 @@ class SymbolTree:
|
|||
|
||||
Returns:
|
||||
An instance of `SymbolTree`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> handler = stree.get_handler()
|
||||
"""
|
||||
return self._symbol_tree
|
||||
|
||||
|
@ -128,6 +143,14 @@ class SymbolTree:
|
|||
|
||||
Returns:
|
||||
A generator for node of current `SymbolTree`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> for node in stree.nodes():
|
||||
... node.set_attribute("channel", 3)
|
||||
"""
|
||||
for node in self._symbol_tree.nodes():
|
||||
yield Node(node)
|
||||
|
@ -144,6 +167,13 @@ class SymbolTree:
|
|||
|
||||
Raises:
|
||||
TypeError: If `node_name` is not `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
"""
|
||||
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
|
||||
node_impl = self._symbol_tree.get_node(node_name)
|
||||
|
@ -157,6 +187,13 @@ class SymbolTree:
|
|||
|
||||
Returns:
|
||||
[Node], the node list of the current `Symboltree`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> inputs = stree.get_inputs()
|
||||
"""
|
||||
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
|
||||
|
||||
|
@ -176,6 +213,15 @@ class SymbolTree:
|
|||
|
||||
Raises:
|
||||
TypeError: if `node` is not a `Node`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> for node in stree.nodes():
|
||||
... if node.get_name() == "conv1":
|
||||
... position = stree.before(node)
|
||||
"""
|
||||
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
||||
return self._symbol_tree.before(node.get_handler())
|
||||
|
@ -196,6 +242,15 @@ class SymbolTree:
|
|||
|
||||
Raises:
|
||||
TypeError: If `node` is not a `Node`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> for node in stree.nodes():
|
||||
... if node.get_name() == "conv1":
|
||||
... position = stree.after(node)
|
||||
"""
|
||||
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
||||
return self._symbol_tree.after(node.get_handler())
|
||||
|
@ -218,6 +273,16 @@ class SymbolTree:
|
|||
RuntimeError: If `position` is not belong to current `SymbolTree`.
|
||||
TypeError: If `position` is not a `Position`.
|
||||
TypeError: If `node` is not a `Node`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> position = stree.after(node)
|
||||
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
|
||||
>>> stree.insert(position, new_node)
|
||||
"""
|
||||
Validator.check_value_type("position", position, [Position], "SymbolTree")
|
||||
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
||||
|
@ -235,6 +300,18 @@ class SymbolTree:
|
|||
|
||||
Raises:
|
||||
TypeError: If `node` is not a `Node`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> input_node = node.get_inputs()[0]
|
||||
>>> output_nodes = node.get_users()
|
||||
>>> for n in output_nodes:
|
||||
... n.set_arg(0, "x")
|
||||
>>> stree.erase_node(node)
|
||||
"""
|
||||
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
||||
return Node(self._symbol_tree.erase_node(node.get_handler()))
|
||||
|
@ -264,9 +341,18 @@ class SymbolTree:
|
|||
An instance of Node represents root of node_tree been replaced in.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Old node is isolated.
|
||||
RuntimeError: Old node is not isolated.
|
||||
TypeError: If `old_node` is not a `Node`.
|
||||
TypeError: If `new_nodes` is not a `list` or node in `new_nodes` is not a `Node`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
|
||||
>>> stree.replace(node, [new_node])
|
||||
"""
|
||||
Validator.check_value_type("old_node", old_node, [Node], "SymbolTree")
|
||||
Validator.check_element_type_of_iterable("new_nodes", new_nodes, [Node], "SymbolTree")
|
||||
|
@ -288,6 +374,13 @@ class SymbolTree:
|
|||
RuntimeError: If `index` is out of range.
|
||||
TypeError: If `index` is not a `int` number.
|
||||
TypeError: If `return_value` is not a `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.set_output(0, "x_10")
|
||||
"""
|
||||
Validator.check_value_type("index", index, [int], "SymbolTree")
|
||||
Validator.check_value_type("return_value", return_value, [str], "SymbolTree")
|
||||
|
@ -295,13 +388,20 @@ class SymbolTree:
|
|||
|
||||
def dump(self):
|
||||
"""
|
||||
Dump graph to console.
|
||||
Print the ir map information corresponding to the network in 'SymbolTree' to the screen.
|
||||
"""
|
||||
self._symbol_tree.dump()
|
||||
|
||||
def print_node_tabulate(self):
|
||||
"""
|
||||
Print node information of graph.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.print_node_tabulate()
|
||||
"""
|
||||
self._symbol_tree.print_node_tabulate()
|
||||
|
||||
|
@ -311,6 +411,13 @@ class SymbolTree:
|
|||
|
||||
Returns:
|
||||
A str represents source code of modified network.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.get_code()
|
||||
"""
|
||||
return self._symbol_tree.get_code()
|
||||
|
||||
|
@ -321,6 +428,13 @@ class SymbolTree:
|
|||
|
||||
Returns:
|
||||
A network object.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.get_network()
|
||||
"""
|
||||
return self._symbol_tree.get_network()
|
||||
|
||||
|
@ -330,16 +444,40 @@ class SymbolTree:
|
|||
|
||||
Args:
|
||||
file_name (str): filename to be set.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.set_saved_file_name("new_net")
|
||||
"""
|
||||
Validator.check_value_type("file_name", file_name, [str], "Saving network")
|
||||
self._symbol_tree.set_saved_file_name(file_name)
|
||||
|
||||
def get_saved_file_name(self):
|
||||
"""Gets the filename used to save the network."""
|
||||
"""
|
||||
Gets the filename used to save the network.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)x
|
||||
>>> stree.set_saved_file_name("new_net")
|
||||
>>> stree.get_saved_file_name()
|
||||
"""
|
||||
return self._symbol_tree.get_saved_file_name()
|
||||
|
||||
def save_network_to_file(self):
|
||||
"""
|
||||
Save the modified network to a file. Default file name is `network_define.py`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.save_network_to_file()
|
||||
"""
|
||||
self._symbol_tree.save_network_to_file()
|
||||
|
|
|
@ -36,7 +36,7 @@ class TreeNodeHelper:
|
|||
Getting symbol_tree from a `Tree` type `Node`.
|
||||
|
||||
Args:
|
||||
node (Node): A `Node` who may hold a sub-symbol_tree.
|
||||
node (Node): A `Node` which may hold a sub-symbol_tree.
|
||||
|
||||
Returns:
|
||||
An instance of SymbolTree represents sub-symbol_tree. Note that `node`'s symbol_tree maybe None, in this
|
||||
|
|
Loading…
Reference in New Issue