rewrite教程放到隐藏页;不展示非转测的接口"

This commit is contained in:
guozhibin 2023-06-27 16:18:11 +08:00
parent 1e12fb9405
commit 0dabd014b2
7 changed files with 821 additions and 1031 deletions

View File

@ -66,4 +66,8 @@ https://www.mindspore.cn/docs/zh-CN/master/api_python/samples/dataset/vision_gal
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/dataset/audio_gallery.ipynb
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/dataset/dataset_gallery.ipynb
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/dataset/text_gallery.ipynb
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/dataset/vision_gallery.ipynb
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/dataset/vision_gallery.ipynb
https://www.mindspore.cn/docs/zh-CN/master/api_python/samples/rewrite/rewrite_tutorial.html
https://www.mindspore.cn/docs/en/master/api_python/samples/rewrite/rewrite_tutorial.html
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/rewrite/rewrite_tutorial.md
https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python_en/samples/rewrite/rewrite_tutorial.md

View File

@ -2,388 +2,7 @@ mindspore.rewrite
=================
MindSpore的ReWrite模块为用户提供了基于自定义规则对网络的前向计算过程进行修改的能力如插入、删除和替换语句。
使用教程
--------
ReWrite完整示例请参考
`rewrite_example.py <https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/rewrite_example.py>`_
该样例代码的主要功能包括怎么通过网络创建SymbolTree并且对SymbolTree中的节点进行插入、删除、替换等操作
其中还包含了对子网络的修改和通过模式匹配进行节点替换。
功能介绍
^^^^^^^^
ReWrite模块使用SymbolTree记录一个网络的前向计算过程其中计算过程的每条代码语句会被展开并以节点的形式存储在SymbolTree中。
ReWrite模块提供了一组新的接口用户可以使用这组接口为一个网络创建SymbolTree然后对SymbolTree里的节点进行修改从而实现对
网络前向计算过程的修改。最后得到修改后的网络代码,或者一个新的网络实例。
创建SymbolTree
^^^^^^^^^^^^^^^
当用户需要使用ReWrite模块对一个网络进行修改时首先需要基于该网络的实例创建一个SymbolTree使用的接口
:func:`mindspore.rewrite.SymbolTree.create`
通过接口 :func:`mindspore.rewrite.SymbolTree.get_code` 可以查看当前SymbolTree里存储的网络代码。
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
net = MyNet()
stree = SymbolTree.create(net)
print(stree.get_code())
运行结果如下:
.. code-block:: python
import sys
sys.path.append('...') # Current working directory
import mindspore
from mindspore import nn
import mindspore.nn as nn
class MyNetOpt(nn.Cell):
def __init__(self, obj):
super().__init__()
for (key, value) in obj.__dict__.items():
setattr(self, key, value)
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
可以看到,通过解析网络 `MyNet` SymbolTree里存储的新网络的类名是 `MyNetOpt` ,相较原网络增加了后缀 ``Opt``
同时init函数的参数和内容均发生了改动新增参数 `obj` 传入的是原始网络的实例,函数里将原始网络的属性信息拷贝到了新的网络里。
新的网络还将当前工作目录保存到 ``sys.path`` 里,从而保证新网络运行时可以搜索到原网络依赖的模块。
通过接口 :func:`mindspore.rewrite.SymbolTree.print_node_tabulate` 可以看到SymbolTree里存储的节点信息及节点拓扑关系。
该接口依赖tabulate模块安装指令为 ``pip install tabulate``
.. code-block:: python
stree.print_node_tabulate()
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
可以看到,网络的前向计算过程的每一条语句均被转换为一个节点,其中每一个节点的名称是唯一的。
SymbolTree里记录了各个Node间的拓扑关系即节点的某个输入来自哪个节点的第几个输出以及节点的某个输出被哪些节点的哪个输入使用。
当前向计算过程中存在复杂语句时创建SymbolTree的过程会将语句展开然后再将展开后的每个语句转换为节点。
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_2(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(0.5 * self.dense(x))
return x
net = MyNet_2()
stree = SymbolTree.create(net)
stree.print_node_tabulate()
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ---------- ------------------------ ------------------------ --------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense dense = self.dense(x) [[0, ('input_x', 0)]] [[0, [('binop_mult', 1)]]]
NodeType.MathOps binop_mult mult_var = (0.5 * dense) [[1, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(mult_var) [[0, ('binop_mult', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
可以看到前向计算过程中写在同一行的dense操作、乘法操作和relu操作被展开为三行代码然后被转换为三个对应节点。
插入节点
^^^^^^^^
当需要在网络的前向计算过程中插入一行新的代码时,可以先使用接口 :func:`mindspore.rewrite.Node.create_call_cell` 创建一个新
的节点,然后使用接口 :func:`mindspore.rewrite.SymbolTree.insert` 将创建的节点插入到SymbolTree内。
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
stree.print_node_tabulate()
在该样例中,插入节点的流程如下:
1. 首先创建了一个新的节点使用的Cell是 ``nn.ReLU()`` ,输入输出均为 ``"x"`` ,节点名是 ``"new_relu"``
2. 接着通过 :func:`mindspore.rewrite.SymbolTree.get_node` 方法获取dense节点。
3. 最后通过 :func:`mindspore.rewrite.SymbolTree.insert` 方法将新创建的节点插入到dense节点后面。
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
可以看到新的new_relu节点插入到dense节点和relu节点间节点的拓扑结构随着节点插入自动更新。
其中,新节点对应代码里的 `self.new_relu` 定义在新网络的init函数里使用传入的 `new_relu_cell` 作为实例。
除了使用 :func:`mindspore.rewrite.SymbolTree.get_node` 方法获取节点来指定插入位置,还可以
通过 :func:`mindspore.rewrite.SymbolTree.nodes` 来遍历节点,并使用 :func:`mindspore.rewrite.SymbolTree.get_instance_type`
基于节点对应实例的类型来获取节点,确定插入位置。
.. code-block:: python
for node in stree.nodes():
if node.get_instance_type() == nn.Dense:
stree.insert(stree.after(node), new_node)
如果希望插入新代码的输出不复用原始网络里的变量,可以在创建节点时使用 :func:`mindspore.rewrite.SymbolTree.unique_name`
到一个SymbolTree内不重名的变量名作为节点的输出。
然后在插入节点前,通过使用 :func:`mindspore.rewrite.Node.set_arg` 修改节点输入变量名,设置哪些节点使用新的节点输出作为输入。
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=[stree.unique_name("x")],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
old_relu_node = stree.get_node("relu")
old_relu_node.set_arg(0, new_node.get_targets()[0])
stree.print_node_tabulate()
在该样例中,创建新节点时 `targets` 参数的值进行了不重名的处理然后将旧的relu节点的输入改为新节点的输出。
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- ---------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x_1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x_1) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
可以看到,新节点的输出变量名是一个不重名的名称 ``x_1`` 且旧的relu节点使用 ``x_1`` 作为输入。
删除节点
^^^^^^^^
当需要在网络的前向计算过程中删除一行代码时,可以使用接口 :func:`mindspore.rewrite.SymbolTree.erase` 来删除节点。
节点删除后,符号树内剩余节点的拓扑关系会依据删除后的代码情况自动更新。
因此,当待删除的节点的输出被别的节点使用时,节点删除后,需要注意剩余节点的拓扑关系是否符合设计预期。
如果待删除节点的前面存在某个节点的输出名和待删除节点的输出名重名,删除节点后,后续使用该输出名作为输入的节点,自动使用前面那个节点
的输出作为输入。拓扑关系会按照该策略更新。
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
stree.erase(relu_node)
stree.print_node_tabulate()
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('dense', 0)]] []
==================================================================================
可以看到因为dense结点的输出和relu结点的输出同名删除relu节点后返回值使用的是dense节点的输出。
如果待删除节点的前面不存在和待删除节点同名的输出,则需要用户先修改后续使用该输出作为输入的节点,更新参数名,然后再
删除节点,以避免删除节点后发生使用了未定义变量的错误。
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_3(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
y = self.dense(x)
z = self.relu(y)
return z
net = MyNet_3()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, relu_node.get_args()[0])
stree.erase(relu_node)
stree.print_node_tabulate()
在该样例中拿到relu节点后先使用接口 :func:`mindspore.rewrite.Node.get_users` 遍历使用relu节点的输出作为输入的节点将这些
节点的输入都改为relu节点的输入然后再删除relu节点。这样的话后续使用了relu节点输出 ``z`` 的地方就都改为使用relu节点输入 ``y`` 了。
具体的参数名修改策略取决于实际场景需求。
运行结果如下:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense y = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return y [[0, ('dense', 0)]] []
==================================================================================
可以看到删除relu节点后最后一个return节点的值从 ``z`` 被更新为 ``y``
替换节点
^^^^^^^^
当需要在网络的前向计算过程中替换代码时,可以使用接口 :func:`mindspore.rewrite.SymbolTree.replace` 来替换节点。
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
relu_node = stree.get_node("relu")
stree.replace(relu_node, [new_node])
stree.print_node_tabulate()
该样例将原始网络里的relu节点替换为new_relu节点运行结果如下
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('new_relu', 0)]] []
==================================================================================
如果替换的新节点的输出和被替换节点的输出名不一致,需要注意维护好替换后的节点间的拓扑关系,即先修改后续使用了被替换节点的输出的节点,
更新这些节点的参数名,然后再进行节点替换操作。
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
# Update the parameter names of subsequent nodes
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, "y1")
# Create two new nodes
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["y1"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_1")
new_relu_cell_2 = nn.ReLU()
new_node_2 = Node.create_call_cell(cell=new_relu_cell_2, targets=["y2"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_2")
# Replace relu node with two new nodes
stree.replace(relu_node, [new_node, new_node_2])
stree.print_node_tabulate()
该用例将relu节点替换为两个新的节点其中第一个节点的输出 ``y1`` 作为返回值更新return节点。运行结果如下
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ---------- ----------------------- ---------------------- -------------------------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0), ('new_relu_1', 0)]]]
NodeType.CallCell new_relu y1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.CallCell new_relu_1 y2 = self.new_relu_1(x) [[0, ('dense', 0)]] []
NodeType.Output return return y1 [[0, ('new_relu', 0)]] []
==================================================================================
可以看出relu节点被成功替换为两个新节点返回值也被更新为第一个新节点的输出。
返回新网络
^^^^^^^^^^
当对网络修改完毕后,就可以使用接口 :func:`mindspore.rewrite.SymbolTree.get_network` 得到修改后的网络实例了。
.. code-block:: python
new_net = stree.get_network()
inputs = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)
outputs = new_net(inputs)
调用该接口后Rewrite模块会先在当前工作目录的rewritten_network文件夹下生成修改后的网络对应的脚本文件然后使用该脚本文件创建新的网络实例
原网络的实例作为参数使用。新的网络实例可以直接用于计算和训练。
如何快速使用ReWrite请参考 `使用ReWrite修改网络 <https://www.mindspore.cn/docs/zh-CN/master/api_python/samples/rewrite/rewrite_tutorial.html>`_
.. py:class:: mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)
@ -537,6 +156,8 @@ SymbolTree里记录了各个Node间的拓扑关系即节点的某个输入来
返回:
str一个SymbolTree内唯一的新的名称名称格式为 `name_n` ,其中 `n` 为数字下标。如果输入 `name` 没有名称冲突,则没有数字下标。
异常:
- **TypeError** - 如果参数 `name` 不是str类型。
.. py:class:: mindspore.rewrite.Node(node: NodeImpl)
@ -615,6 +236,13 @@ SymbolTree里记录了各个Node间的拓扑关系即节点的某个输入来
返回:
NodeType当前节点的类型。
.. py:method:: mindspore.rewrite.Node.get_targets()
获取当前节点的输出值列表。
返回:
输出值列表,参数类型为 ``ScopedValue``
.. py:method:: mindspore.rewrite.Node.get_users()
获取一个节点列表,列表里的节点使用当前节点的输出作为输入。
@ -726,233 +354,3 @@ SymbolTree里记录了各个Node间的拓扑关系即节点的某个输入来
- NamingValue表示对另一个变量的引用。
- CustomObjValue表示自定义类的实例或类型超出ValueType的基本类型和容器类型范围的对象。
.. py:class:: mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)
PatternEngine通过PattenNode修改SymbolTree。
.. warning::
- 这是一组实验性API后续可能修改或删除。
参数:
- **pattern** (Union[PatternNode, List]) - PatternNode的实例或用于构造 `Pattent` 的Cell类型列表。
- **replacement** (callable) - 生成新节点的接口实现。默认值: ``None``
.. py:method:: mindspore.rewrite.PatternEngine.apply(stree: SymbolTree)
`stree` 上面执行当前的匹配模式。
.. note::
当前还不支持子树节点。
参数:
- **stree** (SymbolTree) - 要修改的SymbolTree。
返回:
bool表示是否对 `stree` 进行了修改。
异常:
- **TypeError** - 如果参数 `stree` 不是SymbolTree类型。
.. py:method:: mindspore.rewrite.PatternEngine.pattern()
获取当前的匹配模式。
返回:
PattenNode的实例用来说明当前模式需要匹配的类型。
.. py:class:: mindspore.rewrite.PatternNode(pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None)
PatternNode在定义 `pattern` 时被定义为一个节点。
.. warning::
- 这是一组实验性API后续可能修改或删除。
参数:
- **pattern_node_name** (str) - 节点名称。
- **match_type** (Type) - 当前节点的匹配类型。默认值: ``Type[None]``
- **inputs** (list[PatternNode]) - 当前节点的输入节点。默认值: ``None``
.. py:method:: mindspore.rewrite.PatternNode.add_input(node)
为当前节点添加输入。
参数:
- **node** (PatternNode) - 新增的输入节点。
异常:
- **TypeError** - 如果参数 `node` 不是PattenNode类型。
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_list(type_list: [])
:staticmethod:
使用类型的列表来创建Pattern。
参数:
- **type_list** (list[type]) - 类型列表。
返回:
根据列表生成的模式的根节点。
异常:
- **TypeError** - 如果 `type_list` 不是list类型。
.. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_node(node: Node)
:staticmethod:
根据节点及其输入创建Pattern。
参数:
- **node** (Node) - 要修改的节点。
返回:
根据 `node` 创建的PattentNode。
异常:
- **TypeError** - 如果 `node` 不是Node类型。
.. py:method:: mindspore.rewrite.PatternNode.from_node(node: Node)
:staticmethod:
根据 `node` 创建PatternNode。
参数:
- **node** (Node) - 要修改的节点。
返回:
根据 `node` 创建的PattentNode。
异常:
- **TypeError** - 如果 `node` 不是Node类型。
.. py:method:: mindspore.rewrite.PatternNode.get_inputs()
获取当前节点的输入。
返回:
PattenNode的实例列表当前节点的输入节点。
.. py:method:: mindspore.rewrite.PatternNode.match(node: Node)
检查当前PatternNode是否可以与node匹配。
参数:
- **node** (Node) - 要匹配的节点。
异常:
- **TypeError** - 如果参数 `node` 不是PattenNode类型。
.. py:method:: mindspore.rewrite.PatternNode.name()
获取PattenNode的名称。
.. py:method:: mindspore.rewrite.PatternNode.set_inputs(inputs)
设置当前PatternNode的输入。
参数:
- **inputs** (list[PatternNode]) - 设置为当前PatternNode的输入。
异常:
- **TypeError** - 如果参数 `inputs` 不是list或者 `inputs` 的成员不是PattenNode类型。
.. py:method:: mindspore.rewrite.PatternNode.type()
获取PattenNode的类型。
.. py:class:: mindspore.rewrite.VarNode()
VarNode是PatternNode的子类其匹配方法始终返回True。
.. warning::
- 这是一组实验性API后续可能修改或删除。
.. py:class:: mindspore.rewrite.Replacement
替换的接口定义。
.. warning::
- 这是一组实验性API后续可能修改或删除。
.. py:method:: mindspore.rewrite.Replacement.build(pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict)
:abstractmethod:
用于从匹配结果创建替换节点的接口定义。
.. note::
返回值将作为SymbolTree的替换函数的参数返回值应遵循替换函数参数的 `new_nodes` 的约束。请参阅SymbolTree的 `replace` 的文档字符串中的详细信息。
参数:
- **pattern** (PatternNode) - 当前模式的根节点。
- **is_chain_pattern** (bool) - 标记,标记模式是链模式或树模式。
- **matched** (OrderedDict) - 匹配结果,从名称映射到节点的字典。
返回:
作为替换节点的节点实例列表。
.. py:class:: mindspore.rewrite.TreeNodeHelper
TreeNodeHelper用于在从Tree类型节点获取 `symbol_tree` 时打破循环引用。
TreeNodeHelper提供了静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`
.. warning::
- 这是一组实验性API后续可能修改或删除。
.. py:method:: mindspore.rewrite.TreeNodeHelper.get_sub_tree(node: Node)
:staticmethod:
获取Tree类型节点的 `symbol_tree`
参数:
- **node** (Node) - 可以持有子SymbolTree的节点。
返回:
Tree节点中的SymbolTree对象。注意节点的 `symbol_tree` 可能是None在这种情况下方法将返回None。
异常:
- **RuntimeError** - 如果参数 `node` 不是 NodeType.Tree类型。
- **TypeError** - 如果参数 `node` 不是Node类型实例。
.. py:function:: mindspore.rewrite.sparsify(f, arg_types, sparse_rules=None)
模型自动稀疏化接口,将稠密模型转换为稀疏模型。通过 `arg_types` 指定的参数类型,将稀疏参数在模型中传导,并调用相应的稀疏函数。
.. warning::
- 这是一组实验性API后续可能修改或删除。
参数:
- **f** (Cell) - 被稀疏化的网络。
- **arg_types** (Tuple[ArgType] | Dict[int, ArgType]) - `f` 接受的参数类型稀疏CSR/COO、非稀疏等。如果是tuple长度需要和 `f` 的参数数量相等如果是dict每个键值对应一个参数的索引字典里没有表示的参数默认为非稀疏。
- **sparse_rules** (Dict[str, SparseFunc], 可选) - 自定义稀疏规则。默认值: ``None``
.. py:class:: mindspore.rewrite.ArgType
稀疏化的参数类型。
- CSR表示CSRTensor
- COO表示COOTensor
- NONSPARSE表示非稀疏
.. warning::
- 这是一组实验性API后续可能修改或删除。
.. py:class:: mindspore.rewrite.SparseFunc(fn: Union[str, Callable], inputs: Optional[Any] = None, outputs: Optional[Any] = None)
在稀疏化中表示一个稀疏函数。
.. note::
如果 `fn` 是一个包含类型注解的函数,且同时提供了 `inputs`,则类型注解中的输入类型将被忽略。`outputs` 同理。
.. warning::
- 这是一组实验性API后续可能修改或删除。
参数:
- **fn** (Union[str, Callable]) - 稀疏函数如果是字符串表示一个mindspore.ops.function接口或者是任意函数对象。
- **inputs** (Any, 可选) - 函数的输入类型。如果是 ``None`` ,则使用函数本身的类型注解。默认值: ``None``
- **outputs** (Any, 可选) - 函数的输出类型。如果是 ``None`` ,则使用函数本身的类型注解。默认值: ``None``

View File

@ -0,0 +1,380 @@
# 使用ReWrite修改网络
[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/samples/rewrite/rewrite_tutorial.md)
此指南展示了[mindspore.rewrite](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html)模块中API的各种用法。
ReWrite完整示例请参考
[rewrite_example.py](https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/rewrite_example.py) 。
该样例代码的主要功能包括怎么通过网络创建SymbolTree并且对SymbolTree中的节点进行插入、删除、替换等操作
其中还包含了对子网络的修改和通过模式匹配进行节点替换。
## 功能介绍
ReWrite模块使用SymbolTree记录一个网络的前向计算过程其中计算过程的每条代码语句会被展开并以节点的形式存储在SymbolTree中。
ReWrite模块提供了一组新的接口用户可以使用这组接口为一个网络创建SymbolTree然后对SymbolTree里的节点进行修改从而实现对
网络前向计算过程的修改。最后得到修改后的网络代码,或者一个新的网络实例。
## 创建SymbolTree
当用户需要使用ReWrite模块对一个网络进行修改时首先需要基于该网络的实例创建一个SymbolTree使用的接口
是 [mindspore.rewrite.SymbolTree.create](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.create) 。
通过接口 [mindspore.rewrite.SymbolTree.get_code](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_code) 可以查看当前SymbolTree里存储的网络代码。
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
net = MyNet()
stree = SymbolTree.create(net)
print(stree.get_code())
```
运行结果如下:
``` log
import sys
sys.path.append('...') # Current working directory
import mindspore
from mindspore import nn
import mindspore.nn as nn
class MyNetOpt(nn.Cell):
def __init__(self, obj):
super().__init__()
for (key, value) in obj.__dict__.items():
setattr(self, key, value)
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
```
可以看到,通过解析网络 `MyNet` SymbolTree里存储的新网络的类名是 `MyNetOpt` ,相较原网络增加了后缀 ``Opt`` 。
同时init函数的参数和内容均发生了改动新增参数 `obj]() 传入的是原始网络的实例,函数里将原始网络的属性信息拷贝到了新的网络里。
新的网络还将当前工作目录保存到 ``sys.path`` 里,从而保证新网络运行时可以搜索到原网络依赖的模块。
通过接口 [mindspore.rewrite.SymbolTree.print_node_tabulate](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.print_node_tabulate) 可以看到SymbolTree里存储的节点信息及节点拓扑关系。
该接口依赖tabulate模块安装指令为 ``pip install tabulate`` 。
``` python
stree.print_node_tabulate()
```
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
可以看到,网络的前向计算过程的每一条语句均被转换为一个节点,其中每一个节点的名称是唯一的。
SymbolTree里记录了各个Node间的拓扑关系即节点的某个输入来自哪个节点的第几个输出以及节点的某个输出被哪些节点的哪个输入使用。
当前向计算过程中存在复杂语句时创建SymbolTree的过程会将语句展开然后再将展开后的每个语句转换为节点。
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_2(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(0.5 * self.dense(x))
return x
net = MyNet_2()
stree = SymbolTree.create(net)
stree.print_node_tabulate()
```
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- ---------- ------------------------ ------------------------ --------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense dense = self.dense(x) [[0, ('input_x', 0)]] [[0, [('binop_mult', 1)]]]
NodeType.MathOps binop_mult mult_var = (0.5 * dense) [[1, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(mult_var) [[0, ('binop_mult', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
可以看到前向计算过程中写在同一行的dense操作、乘法操作和relu操作被展开为三行代码然后被转换为三个对应节点。
## 插入节点
当需要在网络的前向计算过程中插入一行新的代码时,可以先使用接口 [mindspore.rewrite.Node.create_call_cell](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.create_call_cell) 创建一个新
的节点,然后使用接口 [mindspore.rewrite.SymbolTree.insert](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.insert) 将创建的节点插入到SymbolTree内。
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
stree.print_node_tabulate()
```
在该样例中,插入节点的流程如下:
1. 首先创建了一个新的节点使用的Cell是 ``nn.ReLU()`` ,输入输出均为 ``"x"`` ,节点名是 ``"new_relu"`` 。
2. 接着通过 [mindspore.rewrite.SymbolTree.get_node](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_node) 方法获取dense节点。
3. 最后通过 [mindspore.rewrite.SymbolTree.insert](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.insert) 方法将新创建的节点插入到dense节点后面。
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
可以看到新的new_relu节点插入到dense节点和relu节点间节点的拓扑结构随着节点插入自动更新。
其中,新节点对应代码里的 `self.new_relu` 定义在新网络的init函数里使用传入的 `new_relu_cell` 作为实例。
除了使用 [mindspore.rewrite.SymbolTree.get_node](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_node) 方法获取节点来指定插入位置,还可以通过 [mindspore.rewrite.SymbolTree.nodes](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.nodes) 来遍历节点,并使用 [mindspore.rewrite.Node.get_instance_type](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.get_instance_type) 基于节点对应实例的类型来获取节点,确定插入位置。
``` python
for node in stree.nodes():
if node.get_instance_type() == nn.Dense:
stree.insert(stree.after(node), new_node)
```
如果希望插入新代码的输出不复用原始网络里的变量,可以在创建节点时使用 [mindspore.rewrite.SymbolTree.unique_name](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.unique_name) 得
到一个SymbolTree内不重名的变量名作为节点的输出。
然后在插入节点前,通过使用 [mindspore.rewrite.Node.set_arg](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.set_arg) 修改节点输入变量名,设置哪些节点使用新的节点输出作为输入。
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=[stree.unique_name("x")],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
old_relu_node = stree.get_node("relu")
old_relu_node.set_arg(0, new_node.get_targets()[0])
stree.print_node_tabulate()
```
在该样例中,创建新节点时 `targets` 参数的值进行了不重名的处理然后将旧的relu节点的输入改为新节点的输出。
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- ---------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x_1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x_1) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
可以看到,新节点的输出变量名是一个不重名的名称 ``x_1`` 且旧的relu节点使用 ``x_1`` 作为输入。
## 删除节点
当需要在网络的前向计算过程中删除一行代码时,可以使用接口 [mindspore.rewrite.SymbolTree.erase](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.erase) 来删除节点。
节点删除后,符号树内剩余节点的拓扑关系会依据删除后的代码情况自动更新。
因此,当待删除的节点的输出被别的节点使用时,节点删除后,需要注意剩余节点的拓扑关系是否符合设计预期。
如果待删除节点的前面存在某个节点的输出名和待删除节点的输出名重名,删除节点后,后续使用该输出名作为输入的节点,自动使用前面那个节点
的输出作为输入。拓扑关系会按照该策略更新。
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
stree.erase(relu_node)
stree.print_node_tabulate()
```
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('dense', 0)]] []
==================================================================================
```
可以看到因为dense结点的输出和relu结点的输出同名删除relu节点后返回值使用的是dense节点的输出。
如果待删除节点的前面不存在和待删除节点同名的输出,则需要用户先修改后续使用该输出作为输入的节点,更新参数名,然后再
删除节点,以避免删除节点后发生使用了未定义变量的错误。
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_3(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
y = self.dense(x)
z = self.relu(y)
return z
net = MyNet_3()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, relu_node.get_args()[0])
stree.erase(relu_node)
stree.print_node_tabulate()
```
在该样例中拿到relu节点后先使用接口 [mindspore.rewrite.Node.get_users](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.get_users) 遍历使用relu节点的输出作为输入的节点将这些
节点的输入都改为relu节点的输入然后再删除relu节点。这样的话后续使用了relu节点输出 ``z`` 的地方就都改为使用relu节点输入 ``y`` 了。
具体的参数名修改策略取决于实际场景需求。
运行结果如下:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense y = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return y [[0, ('dense', 0)]] []
==================================================================================
```
可以看到删除relu节点后最后一个return节点的值从 ``z`` 被更新为 ``y`` 。
## 替换节点
当需要在网络的前向计算过程中替换代码时,可以使用接口 [mindspore.rewrite.SymbolTree.replace](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.replace) 来替换节点。
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
relu_node = stree.get_node("relu")
stree.replace(relu_node, [new_node])
stree.print_node_tabulate()
```
该样例将原始网络里的relu节点替换为new_relu节点运行结果如下
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('new_relu', 0)]] []
==================================================================================
```
如果替换的新节点的输出和被替换节点的输出名不一致,需要注意维护好替换后的节点间的拓扑关系,即先修改后续使用了被替换节点的输出的节点,
更新这些节点的参数名,然后再进行节点替换操作。
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
# Update the parameter names of subsequent nodes
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, "y1")
# Create two new nodes
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["y1"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_1")
new_relu_cell_2 = nn.ReLU()
new_node_2 = Node.create_call_cell(cell=new_relu_cell_2, targets=["y2"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_2")
# Replace relu node with two new nodes
stree.replace(relu_node, [new_node, new_node_2])
stree.print_node_tabulate()
```
该用例将relu节点替换为两个新的节点其中第一个节点的输出 ``y1`` 作为返回值更新return节点。运行结果如下
``` log
================================================================================
node type name codes arg providers target users
----------------- ---------- ----------------------- ---------------------- -------------------------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0), ('new_relu_1', 0)]]]
NodeType.CallCell new_relu y1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.CallCell new_relu_1 y2 = self.new_relu_1(x) [[0, ('dense', 0)]] []
NodeType.Output return return y1 [[0, ('new_relu', 0)]] []
==================================================================================
```
可以看出relu节点被成功替换为两个新节点返回值也被更新为第一个新节点的输出。
## 返回新网络
当对网络修改完毕后,就可以使用接口 [mindspore.rewrite.SymbolTree.get_network](https://mindspore.cn/docs/zh-CN/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_network) 得到修改后的网络实例了。
``` python
from mindspore import Tensor
from mindspore.common import dtype as mstype
import numpy as np
new_net = stree.get_network()
inputs = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)
outputs = new_net(inputs)
```
调用该接口后Rewrite模块会先在当前工作目录的rewritten_network文件夹下生成修改后的网络对应的脚本文件然后使用该脚本文件创建新的网络实例
原网络的实例作为参数使用。新的网络实例可以直接用于计算和训练。

View File

@ -4,425 +4,9 @@ mindspore.rewrite
The ReWrite module in MindSpore provides users with the ability to modify the network's forward computation
process based on custom rules, such as inserting, deleting, and replacing statements.
Tutorial for use
-----------------
For a complete ReWrite example, refer to
`rewrite_example.py <https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python_en/rewrite_example.py>`_
The main functions of the sample code include: how to create a SymbolTree through the network, and how to insert, delete,
and replace the nodes in the SymbolTree. It also includes the modification of the subnet and node replacement through pattern
matching.
Function Introduction
^^^^^^^^^^^^^^^^^^^^^^
ReWrite module uses SymbolTree to record the forward computation of a network, where each code statement of the
forward computation process is expanded and stored in the SymbolTree as nodes.
The ReWrite module provides a new set of interfaces that users can use to create a SymbolTree for a network and then
modify the nodes in the SymbolTree to achieve the network forward computation process modification. Finally, a modified
network code, or a new network instance can be obtained.
Create A SymbolTree
^^^^^^^^^^^^^^^^^^^^
When we need to modify a network using the ReWrite module, we first needs to create a SymbolTree based on the instance
of the network, using the interface :func:`mindspore.rewrite.SymbolTree.create` .
Through the use of the interface :func:`mindspore.rewrite.SymbolTree.get_code` , we can view the network code currently
stored in SymbolTree.
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
net = MyNet()
stree = SymbolTree.create(net)
print(stree.get_code())
The results are as follows:
.. code-block:: python
import sys
sys.path.append('...') # Current working directory
import mindspore
from mindspore import nn
import mindspore.nn as nn
class MyNetOpt(nn.Cell):
def __init__(self, obj):
super().__init__()
for (key, value) in obj.__dict__.items():
setattr(self, key, value)
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
It can be seen that by parsing the network `MyNet` , the class name of the new network stored in SymbolTree is `MyNetOpt` ,
which adds the suffix ``Opt`` to the original network.
At the same time, the parameters and content of the init function have been changed. The new parameter `obj` is passed into
the instance of the original network, and the attribute information of the original network is copied to the new network in
the function.
The new network also saves the current working directory to ``sys.path`` , ensuring that modules that the original network
depends on can be searched for when running on the new network.
By using the interface :func:`mindspore.rewrite.SymbolTree.print_node_tabulate` , we can see the node information and node
topology relationships stored in the SymbolTree.
This interface depends on the tabulate module, and the installation command is: ``pip install tabulate`` .
.. code-block:: python
stree.print_node_tabulate()
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
It can be seen that each statement in the network's forward computation process is converted to a node, where the name
of each node is unique.
The SymbolTree records the topological relationship between each node, that is, the output of which node an input comes
from, and the output of a node is used by which input of which node.
When there are complex statements in the forward computation process, the statements are expanded during the creation
of SymbolTree, and then each expanded statement is converted to a node.
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_2(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(0.5 * self.dense(x))
return x
net = MyNet_2()
stree = SymbolTree.create(net)
stree.print_node_tabulate()
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ---------- ------------------------ ------------------------ --------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense dense = self.dense(x) [[0, ('input_x', 0)]] [[0, [('binop_mult', 1)]]]
NodeType.MathOps binop_mult mult_var = (0.5 * dense) [[1, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(mult_var) [[0, ('binop_mult', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
It can be seen that the dense, multiplication, and relu operations written on the same line during forward computing are
expanded into three lines of code and then converted into three corresponding nodes.
Insert Nodes
^^^^^^^^^^^^^
When we need to insert a new line of code during the forward computation of the network, we can first create a new node
using interface :func:`mindspore.rewrite.Node.create_call_cell` , and then insert the created node into SymbolTree
using interface :func:`mindspore.rewrite.SymbolTree.insert` .
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
stree.print_node_tabulate()
In this example, the process for inserting a node is as follows:
1. Firstly, a new node is created. The Cell used is ``nn.ReLU()`` , the input and output are ``"x"`` , and the node name is ``"new_relu"`` .
2. Then the dense node is fetched by using :func:`mindspore.rewrite.SymbolTree.get_node` .
3. Finally, the newly created node is inserted after the dense node through :func:`mindspore.rewrite.SymbolTree.insert` .
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
It can be seen that the new new_relu node is inserted between the dense node and the relu node, and the topology of
node is automatically updated with the node insertion.
The definition of `self.new_relu` in the code of new node is saved in the init function of the new network, using
parameter `new_relu_cell` as the instance.
In addition to getting nodes using :func:`mindspore.rewrite.SymbolTree.get_node` to specify the insertion location, we can
also iterate through nodes by :func:`mindspore.rewrite.SymbolTree.nodes` and use :func:`mindspore.rewrite.SymbolTree.get_instance_type`
to get the node and determine the insertion position based on the type of corresponding instance of node.
.. code-block:: python
for node in stree.nodes():
if node.get_instance_type() == nn.Dense:
stree.insert(stree.after(node), new_node)
If we want the output of new code to be inserted does not reuse variables from the original network, we can
use :func:`mindspore.rewrite.SymbolTree.unique_name` to get an variable name that are not duplicated in the SymbolTree
as the output of node when creating nodes.
Then, before inserting the node, we can modify the node input variable name by using :func:`mindspore.rewrite.Node.set_arg`
to set which nodes use the new node output as input.
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=[stree.unique_name("x")],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
old_relu_node = stree.get_node("relu")
old_relu_node.set_arg(0, new_node.get_targets()[0])
stree.print_node_tabulate()
In this example, when creating a new node, the value of the `targets` parameter is treated without duplication,
and the input of old relu node is changed to the output of new node.
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- ---------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x_1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x_1) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
It can be seen that the output variable name of new node is an unnamed name ``x_1`` , and the old relu node uses ``x_1`` as input.
Delete Nodes
^^^^^^^^^^^^^
When we need to delete a line of code during the forward computation of the network, we can use the interface
:func:`mindspore.rewrite.SymbolTree.erase` to delete the node.
After the node is deleted, the topological relationship of the remaining nodes in the symbol tree will be automatically
updated according to the code of network after deletion.
Therefore, when the output of node to be deleted is used by other nodes, we need to pay attention to whether the topological
relationship of the remaining nodes meets the design expectations after the node is deleted.
If a node exists in front of the node to be deleted that has the same output name as the node to be deleted, after the node
is deleted, the output of the previous node is automatically used as input for the node that uses the output name as the input.
The topology relationship is updated according to this policy.
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
stree.erase(relu_node)
stree.print_node_tabulate()
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('dense', 0)]] []
==================================================================================
It can be seen that because the output of dense node and the output of relu node have the same name, after deleting
the relu node, the return value uses the output of the dense node.
If there is no node that has the same output name as the node to be deleted in front of the node to be deleted, we need
to modify subsequent nodes that uses this output as input by updating the input names, and then delete the node, in order
to avoid errors using undefined variables after deleting the node.
.. code-block:: python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_3(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
y = self.dense(x)
z = self.relu(y)
return z
net = MyNet_3()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, relu_node.get_args()[0])
stree.erase(relu_node)
stree.print_node_tabulate()
In this example, after getting the relu node, first we use the interface :func:`mindspore.rewrite.Node.get_users` to
iterate through the nodes that use the output of relu node as input, change the input of these nodes to the input of relu
node, and then delete the relu node. In this case, the subsequent use of the relu node output ``z`` will be changed to
the relu node input ``y`` .
The specific parameter name modification strategy depends on the actual scenario requirements.
The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense y = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return y [[0, ('dense', 0)]] []
==================================================================================
It can be seen that after deleting the relu node, the value of the last return node is updated from ``z`` to ``y`` .
Replace Nodes
^^^^^^^^^^^^^^
When we need to replace code during the forward computation of network, we can replace the node with the
interface :func:`mindspore.rewrite.SymbolTree.replace` .
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
relu_node = stree.get_node("relu")
stree.replace(relu_node, [new_node])
stree.print_node_tabulate()
This example replaces relu node in the original network with new_relu node. The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('new_relu', 0)]] []
==================================================================================
If the output of the replaced node and the output name of the replaced node are inconsistent, we need to pay attention
to maintaining the topological relationship between nodes after replacement, that is, first modify the subsequent nodes that
uses the output of the replaced node, update the parameter names of these nodes, and then perform the node replacement operation.
.. code-block:: python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
# Update the parameter names of subsequent nodes
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, "y1")
# Create two new nodes
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["y1"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_1")
new_relu_cell_2 = nn.ReLU()
new_node_2 = Node.create_call_cell(cell=new_relu_cell_2, targets=["y2"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_2")
# Replace relu node with two new nodes
stree.replace(relu_node, [new_node, new_node_2])
stree.print_node_tabulate()
The example replaces relu node with two new nodes, where the output of first node ``y1`` is used as the return value in the
return node. The results are as follows:
.. code-block::
================================================================================
node type name codes arg providers target users
----------------- ---------- ----------------------- ---------------------- -------------------------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0), ('new_relu_1', 0)]]]
NodeType.CallCell new_relu y1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.CallCell new_relu_1 y2 = self.new_relu_1(x) [[0, ('dense', 0)]] []
NodeType.Output return return y1 [[0, ('new_relu', 0)]] []
==================================================================================
It can be seen that the relu node was successfully replaced with two new nodes, and the return value was also
updated to the output of the first new node.
Return A New Network
^^^^^^^^^^^^^^^^^^^^^
When the network is modified, we can use the interface :func:`mindspore.rewrite.SymbolTree.get_network` to get the
modified network instance.
.. code-block:: python
new_net = stree.get_network()
inputs = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)
outputs = new_net(inputs)
After calling this interface, rewrite module will first generate a script file corresponding to the modified network in the
rewritten_network folder of the current working directory, and then use the script file to create a new network instance,
and use the original network instance as a parameter. New network instances can be used directly for compute and training.
For a quick start of using ReWrite, please refer to `Modifying Network With ReWrite <https://www.mindspore.cn/docs/en/master/api_python/samples/rewrite/rewrite_tutorial.html>`_ .
.. automodule:: mindspore.rewrite
:exclude-members: SparseFunc
:exclude-members: SparseFunc, PatternEngine, PatternNode, VarNode, Replacement, TreeNodeHelper, sparsify, ArgType
:members:
.. autoclass:: mindspore.rewrite.SparseFunc

View File

@ -0,0 +1,415 @@
# Modifying Network With ReWrite
[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.png)](https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python_en/samples/rewrite/rewrite_tutorial.md)
This example illustrates the various usages of APIs available in the [mindspore.rewrite](https://www.mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html) module.
For a complete ReWrite example, refer to
[rewrite_example.py](https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python_en/rewrite_example.py) .
The main functions of the sample code include: how to create a SymbolTree through the network, and how to insert, delete,
and replace the nodes in the SymbolTree. It also includes the modification of the subnet and node replacement through pattern
matching.
## Function Introduction
ReWrite module uses SymbolTree to record the forward computation of a network, where each code statement of the
forward computation process is expanded and stored in the SymbolTree as nodes.
The ReWrite module provides a new set of interfaces that users can use to create a SymbolTree for a network and then
modify the nodes in the SymbolTree to achieve the network forward computation process modification. Finally, a modified
network code, or a new network instance can be obtained.
## Creating A SymbolTree
When we need to modify a network using the ReWrite module, we first need to create a SymbolTree based on the instance
of the network, using the interface [mindspore.rewrite.SymbolTree.create](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.create) .
Through the use of the interface [mindspore.rewrite.SymbolTree.get_code](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_code), we can view the network code currently
stored in SymbolTree.
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
net = MyNet()
stree = SymbolTree.create(net)
print(stree.get_code())
```
The results are as follows:
``` log
import sys
sys.path.append('...') # Current working directory
import mindspore
from mindspore import nn
import mindspore.nn as nn
class MyNetOpt(nn.Cell):
def __init__(self, obj):
super().__init__()
for (key, value) in obj.__dict__.items():
setattr(self, key, value)
def construct(self, x):
x = self.dense(x)
x = self.relu(x)
return x
```
It can be seen that by parsing the network `MyNet` , the class name of the new network stored in SymbolTree is `MyNetOpt` ,
which adds the suffix ``Opt`` to the original network.
At the same time, the parameters and content of the init function have been changed. The new parameter `obj` is passed into
the instance of the original network, and the attribute information of the original network is copied to the new network in
the function.
The new network also saves the current working directory to ``sys.path`` , ensuring that modules that the original network
depends on can be searched for when running on the new network.
By using the interface [mindspore.rewrite.SymbolTree.print_node_tabulate](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.print_node_tabulate) , we can see the node information and node
topology relationships stored in the SymbolTree.
This interface depends on the tabulate module, and the installation command is: ``pip install tabulate`` .
``` python
stree.print_node_tabulate()
```
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
It can be seen that each statement in the network's forward computation process is converted to a node, where the name
of each node is unique.
The SymbolTree records the topological relationship between each node, that is, the output of which node an input comes
from, and the output of a node is used by which input of which node.
When there are complex statements in the forward computation process, the statements are expanded during the creation
of SymbolTree, and then each expanded statement is converted to a node.
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_2(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(0.5 * self.dense(x))
return x
net = MyNet_2()
stree = SymbolTree.create(net)
stree.print_node_tabulate()
```
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- ---------- ------------------------ ------------------------ --------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense dense = self.dense(x) [[0, ('input_x', 0)]] [[0, [('binop_mult', 1)]]]
NodeType.MathOps binop_mult mult_var = (0.5 * dense) [[1, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(mult_var) [[0, ('binop_mult', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
It can be seen that the dense, multiplication, and relu operations written on the same line during forward computing are
expanded into three lines of code and then converted into three corresponding nodes.
## Inserting Nodes
When we need to insert a new line of code during the forward computation of the network, we can first create a new node
using interface [mindspore.rewrite.Node.create_call_cell](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.create_call_cell) , and then insert the created node into SymbolTree
using interface [mindspore.rewrite.SymbolTree.insert](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.insert) .
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
stree.print_node_tabulate()
```
In this example, the process for inserting a node is as follows:
1. Firstly, a new node is created. The Cell used is ``nn.ReLU()`` , the input and output are ``"x"`` , and the node name is ``"new_relu"`` .
2. Then the dense node is fetched by using [mindspore.rewrite.SymbolTree.get_node](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_node) .
3. Finally, the newly created node is inserted after the dense node through [mindspore.rewrite.SymbolTree.insert](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.insert) .
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
It can be seen that the new new_relu node is inserted between the dense node and the relu node, and the topology of
node is automatically updated with the node insertion.
The definition of `self.new_relu` in the code of new node is saved in the init function of the new network, using
parameter `new_relu_cell` as the instance.
In addition to getting nodes using [mindspore.rewrite.SymbolTree.get_node](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_node) to specify the insertion location, we can
also iterate through nodes by [mindspore.rewrite.SymbolTree.nodes](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.nodes) and use [mindspore.rewrite.Node.get_instance_type](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.get_instance_type)
to get the node and determine the insertion position based on the type of corresponding instance of node.
``` python
for node in stree.nodes():
if node.get_instance_type() == nn.Dense:
stree.insert(stree.after(node), new_node)
```
If we want the output of new code to be inserted does not reuse variables from the original network, we can
use [mindspore.rewrite.SymbolTree.unique_name](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.unique_name) to get an variable name that are not duplicated in the SymbolTree
as the output of node when creating nodes.
Then, before inserting the node, we can modify the node input variable name by using [mindspore.rewrite.Node.set_arg](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.set_arg)
to set which nodes use the new node output as input.
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=[stree.unique_name("x")],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
dense_node = stree.get_node("dense")
stree.insert(stree.after(dense_node), new_node)
old_relu_node = stree.get_node("relu")
old_relu_node.set_arg(0, new_node.get_targets()[0])
stree.print_node_tabulate()
```
In this example, when creating a new node, the value of the `targets` parameter is treated without duplication,
and the input of old relu node is changed to the output of new node.
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- ---------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x_1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('relu', 0)]]]
NodeType.CallCell relu x = self.relu(x_1) [[0, ('new_relu', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('relu', 0)]] []
==================================================================================
```
It can be seen that the output variable name of new node is an unnamed name ``x_1`` , and the old relu node uses ``x_1`` as input.
## Deleting Nodes
When we need to delete a line of code during the forward computation of the network, we can use the interface
[mindspore.rewrite.SymbolTree.erase](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.erase) to delete the node.
After the node is deleted, the topological relationship of the remaining nodes in the symbol tree will be automatically
updated according to the code of network after deletion.
Therefore, when the output of node to be deleted is used by other nodes, we need to pay attention to whether the topological
relationship of the remaining nodes meets the design expectations after the node is deleted.
If a node exists in front of the node to be deleted that has the same output name as the node to be deleted, after the node
is deleted, the output of the previous node is automatically used as input for the node that uses the output name as the input.
The topology relationship is updated according to this policy.
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
stree.erase(relu_node)
stree.print_node_tabulate()
```
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('dense', 0)]] []
==================================================================================
```
It can be seen that because the output of dense node and the output of relu node have the same name, after deleting
the relu node, the return value uses the output of the dense node.
If there is no node that has the same output name as the node to be deleted in front of the node to be deleted, we need
to modify subsequent nodes that uses this output as input by updating the input names, and then delete the node, in order
to avoid errors using undefined variables after deleting the node.
``` python
import mindspore.nn as nn
from mindspore.rewrite import SymbolTree
class MyNet_3(nn.Cell):
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, has_bias=False, weight_init="ones")
self.relu = nn.ReLU()
def construct(self, x):
y = self.dense(x)
z = self.relu(y)
return z
net = MyNet_3()
stree = SymbolTree.create(net)
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, relu_node.get_args()[0])
stree.erase(relu_node)
stree.print_node_tabulate()
```
In this example, after getting the relu node, first we use the interface [mindspore.rewrite.Node.get_users](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.Node.get_users) to
iterate through the nodes that use the output of relu node as input, change the input of these nodes to the input of relu
node, and then delete the relu node. In this case, the subsequent use of the relu node output ``z`` will be changed to
the relu node input ``y`` .
The specific parameter name modification strategy depends on the actual scenario requirements.
The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- ------- ----------------- --------------------- ----------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense y = self.dense(x) [[0, ('input_x', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return y [[0, ('dense', 0)]] []
==================================================================================
```
It can be seen that after deleting the relu node, the value of the last return node is updated from ``z`` to ``y`` .
## Replacing Nodes
When we need to replace code during the forward computation of network, we can replace the node with the
interface [mindspore.rewrite.SymbolTree.replace](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.replace) .
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["x"],
args=[ScopedValue.create_naming_value("x")], name="new_relu")
relu_node = stree.get_node("relu")
stree.replace(relu_node, [new_node])
stree.print_node_tabulate()
```
This example replaces relu node in the original network with new_relu node. The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- -------- -------------------- ---------------------- ------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0)]]]
NodeType.CallCell new_relu x = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.Output return return x [[0, ('new_relu', 0)]] []
==================================================================================
```
If the output name of the new node and the replaced node are inconsistent, we need to pay attention
to maintaining the topological relationship between nodes after replacement, that is, first modify the subsequent nodes that
uses the output of the replaced node, update the parameter names of these nodes, and then perform the node replacement operation.
``` python
from mindspore.rewrite import SymbolTree, Node, ScopedValue
net = MyNet()
stree = SymbolTree.create(net)
# Update the parameter names of subsequent nodes
relu_node = stree.get_node("relu")
for node in relu_node.get_users():
node.set_arg(0, "y1")
# Create two new nodes
new_relu_cell = nn.ReLU()
new_node = Node.create_call_cell(cell=new_relu_cell, targets=["y1"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_1")
new_relu_cell_2 = nn.ReLU()
new_node_2 = Node.create_call_cell(cell=new_relu_cell_2, targets=["y2"],
args=[ScopedValue.create_naming_value("x")], name="new_relu_2")
# Replace relu node with two new nodes
stree.replace(relu_node, [new_node, new_node_2])
stree.print_node_tabulate()
```
The example replaces relu node with two new nodes, where the output of first node ``y1`` is used as the return value in the
return node. The results are as follows:
``` log
================================================================================
node type name codes arg providers target users
----------------- ---------- ----------------------- ---------------------- -------------------------------------------
NodeType.Input input_x x [] [[0, [('dense', 0)]]]
NodeType.CallCell dense x = self.dense(x) [[0, ('input_x', 0)]] [[0, [('new_relu', 0), ('new_relu_1', 0)]]]
NodeType.CallCell new_relu y1 = self.new_relu(x) [[0, ('dense', 0)]] [[0, [('return', 0)]]]
NodeType.CallCell new_relu_1 y2 = self.new_relu_1(x) [[0, ('dense', 0)]] []
NodeType.Output return return y1 [[0, ('new_relu', 0)]] []
==================================================================================
```
It can be seen that the relu node was successfully replaced with two new nodes, and the return value was also
updated to the output of the first new node.
## Returning A New Network
When the network is modified, we can use the interface [mindspore.rewrite.SymbolTree.get_network](https://mindspore.cn/docs/en/master/api_python/mindspore.rewrite.html#mindspore.rewrite.SymbolTree.get_network) to get the
modified network instance.
``` python
from mindspore import Tensor
from mindspore.common import dtype as mstype
import numpy as np
new_net = stree.get_network()
inputs = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)
outputs = new_net(inputs)
```
After calling this interface, rewrite module will first generate a script file corresponding to the modified network in the
rewritten_network folder of the current working directory, and then use the script file to create a new network instance,
and use the original network instance as a parameter. New network instances can be used directly for compute and training.

View File

@ -236,6 +236,12 @@ class Node:
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
def get_targets(self) -> [ScopedValue]:
"""
Gets a list of output values for the current node.
Returns:
A list of outputs of type ``ScopedValue`` .
"""
return self._node.get_targets()
def get_name(self) -> str:

View File

@ -389,6 +389,9 @@ class SymbolTree:
Returns:
str, A new, unique name within a symbol tree in the format `name_n`, where `n` is a numeric subscript.
If there is no name conflict when entered `name`, there is no numeric subscript.
Raises:
TypeError: The type of `name` is not str.
"""
Validator.check_value_type("name", name, [str], "SymbolTree")
return self._symbol_tree.unique_name(name)