del rewrite api
This commit is contained in:
parent
f292133c39
commit
ca0df4f001
|
@ -54,26 +54,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
异常:
|
||||
- **TypeError** - 参数 `network` 不是Cell类型对象。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.create_call_function(func, targets, args, kwargs)
|
||||
|
||||
创建一个Node对象,并生成执行代码插入源码中。源码中以 `args` 和 `kwargs` 为参数调用 `func` 函数。
|
||||
|
||||
参数:
|
||||
- **func** (FunctionType) - 要被调用的函数。
|
||||
- **targets** (list[str]) - 表示输出名称。在源代码中作为节点的输出。
|
||||
- **args** (Union[MsDtypes, ParamTypes]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 `cell` 没有参数输入。
|
||||
- **kwargs** (dict{str,Union[MsDtypes, ParamTypes]}) - 键的类型必须是str,值必须是MsDtypes或类型必须是ParamTypes。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。默认为None,表示没有 `kwargs` 输入。
|
||||
|
||||
返回:
|
||||
一个Node实例。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数 `func` 不是FunctionType类型。
|
||||
- **TypeError** - 如果参数 `targets` 不是list类型。
|
||||
- **TypeError** - 如果参数 `targets` 的成员不是str类型。
|
||||
- **TypeError** - 如果参数 `args` 不是ParamType类型。
|
||||
- **TypeError** - 如果参数 `kwarg` 的 `key` 不是str类型或者 `value` 不是ParamType类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.dump()
|
||||
|
||||
将 `SymbolTree` 中network对应的ir图信息打印到屏幕。
|
||||
|
@ -98,13 +78,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
str,SymbolTree对应的源码字符串。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.get_handler()
|
||||
|
||||
获取SymbolTree对应实现的handle。
|
||||
|
||||
返回:
|
||||
SymbolTree对象。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.get_network()
|
||||
|
||||
获取SymbolTree所对应的生成的网络对象。源码会保存到文件中,默认的文件名为 `network_define.py`。
|
||||
|
@ -112,19 +85,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
根据SymbolTree生成的网络对象。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.get_node(node_name: str)
|
||||
|
||||
获取节点名为 `node_name` 的节点。
|
||||
|
||||
参数:
|
||||
- **node_name** (str) - 节点的名称。
|
||||
|
||||
返回:
|
||||
如果找到则返回结果,否则返回 `None`。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `node_name` 不是Node类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.insert(position, node: Node)
|
||||
|
||||
在SymbolTree的 `position` 位置插入一个节点。 `position` 可以通过 `before` 或 `after` 来获得。
|
||||
|
@ -148,10 +108,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
当前SymbolTree中节点的生成器。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.print_node_tabulate()
|
||||
|
||||
打印当前SymbolTree的节点信息表格。
|
||||
|
||||
.. py:method:: mindspore.rewrite.SymbolTree.replace(old_node: Node, new_nodes: [Node])
|
||||
|
||||
使用新节点列表来替代旧节点。
|
||||
|
@ -209,45 +165,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
- **TypeError** - 如果参数 `args` 不是ScopedValue类型。
|
||||
- **TypeError** - 如果参数 `kwarg` 的 `key` 不是str类型或者 `value` 不是ScopedValue类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_args()
|
||||
|
||||
获取当前节点的参数。
|
||||
|
||||
- 当前节点的 `node_type` 为 `CallCell`、 `CallPrimitive` 或 `Tree` 时,返回值对应于 ast.Call 的 `args`,表示调用 `cell-op` 或 `primitive-op` 的 `forward` 方法的参数。
|
||||
- 当前节点的 `node_type` 为 `Input` 时,返回值为函数参数的默认值。
|
||||
- 当前节点的 `node_type` 为 `Output` 时,返回值为网络的返回值。
|
||||
- 当前节点的 `node_type` 为 `Python` 时,没有实际含义,可以忽略。
|
||||
|
||||
返回:
|
||||
`ScopedValue` 实例的列表。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_attribute(key: str)
|
||||
|
||||
获取当前节点属性 `key` 的值。
|
||||
|
||||
参数:
|
||||
- **key** (str) - 属性的名称。
|
||||
|
||||
返回:
|
||||
属性值,可能是任意类型。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数 `key` 不是str类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_attributes()
|
||||
|
||||
获取当前节点的所有属性。
|
||||
|
||||
返回:
|
||||
返回一个包含属性名和属性值的字典。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_handler()
|
||||
|
||||
获取节点具体实现的handle。
|
||||
|
||||
返回:
|
||||
返回NodeImpl的实例。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_inputs()
|
||||
|
||||
获取当前节点的拓扑序的输入节点。
|
||||
|
@ -255,18 +172,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
Node的实例列表。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_instance()
|
||||
|
||||
获取当前节点对应的 `operation` 实例。
|
||||
|
||||
- 如果当前节点的 `node_type` 是 `CallCell`,该节点的实例是一个Cell的对象。
|
||||
- 如果当前节点的 `node_type` 是 `CallPrimitive`,该节点的实例是一个Primitive的对象。
|
||||
- 如果当前节点的 `node_type` 是 `Tree`,该节点的实例是一个网络的对象。
|
||||
- 如果当前节点的 `node_type` 是 `Python`、 `Input`、 `Output`、 `CallMethod`,该节点的实例为None。
|
||||
|
||||
返回:
|
||||
当前节点的 `operation` 实例。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_instance_type()
|
||||
|
||||
获取当前节点对应的 `operation` 实例类型。
|
||||
|
@ -279,16 +184,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
当前节点的 `operation` 类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_kwargs()
|
||||
|
||||
获取当前节点带 `key` 值的参数。
|
||||
|
||||
- 当前节点的 `node_type` 为 `CallCell`、 `CallPrimitive` 或 `Tree` 时,关键字参数对应于 `ast.Call` 的 `kwargs`,表示调用 `cell-op` 或 `Primitive-op` 方法的参数。
|
||||
- 当前节点的 `node_type` 为 `Python`、 `Input` 或 `Output` 时,不关心关键字参数。
|
||||
|
||||
返回:
|
||||
`key` 为str, `value` 为ScopedValue的字典。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_name()
|
||||
|
||||
获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。
|
||||
|
@ -303,17 +198,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
返回:
|
||||
NodeType,当前节点的类型。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_targets()
|
||||
|
||||
获取当前节点的输出名称。
|
||||
|
||||
- 当前节点的 `node_type` 为 `CallCell`、 `CallPrimitive`、 `CallMethod` 或 `Tree` 时, `target` 为字符串,表示单元操作或原始操作或函数调用的调用结果,它们对应于 `ast.Assign` 的 `targets`。
|
||||
- 当前节点的 `node_type` 为 `Input` 时, `targets` 应该只有一个元素,字符串代表函数的参数。
|
||||
- 当前节点的 `node_type` 为 `Python` 或 `Output` 时, `target` 不需要关心。
|
||||
|
||||
返回:
|
||||
节点输出的ScopedValue列表。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.get_users()
|
||||
|
||||
按拓扑顺序获取当前节点的输出节点。
|
||||
|
@ -352,17 +236,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进
|
|||
- **ValueError** - 如果参数 `out_idx` 超出了 `src_node` 的输出数量。
|
||||
- **ValueError** - 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,参数 `src_node` 有多个输出。
|
||||
|
||||
.. py:method:: mindspore.rewrite.Node.set_attribute(key: str, value)
|
||||
|
||||
设置当前节点的属性。
|
||||
|
||||
参数:
|
||||
- **key** (str) - 属性的名称。
|
||||
- **value** (object) - 属性值。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果参数 `key` 不是str类型。
|
||||
|
||||
.. py:class:: mindspore.rewrite.NodeType
|
||||
|
||||
NodeType表示Node的类型。
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
This example mainly illustrates the usage of rewrite.
|
||||
"""
|
||||
from typing import OrderedDict
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor, export
|
||||
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
|
||||
TreeNodeHelper
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class SubNet(nn.Cell):
|
||||
"""子网络定义"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
|
||||
self.mean = ops.ReduceMean(keep_dims=False)
|
||||
self.conv1 = nn.Conv2d(1, 1, 1, stride=1)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""网络定义"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
self.simnet = SubNet()
|
||||
|
||||
def construct(self, x):
|
||||
"""网络的前向计算过程"""
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.simnet(x)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_stree(network):
|
||||
"""创建SymbolTree"""
|
||||
stree = SymbolTree.create(network)
|
||||
stree.dump()
|
||||
return stree
|
||||
|
||||
|
||||
def insert_node(stree):
|
||||
"""在网络中插入节点"""
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv2": # 在名称为'conv2'的节点前面插入新的节点
|
||||
position = stree.before(node)
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=node.get_args())
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
# 使用新节点更新已有节点的输入
|
||||
if new_conv_node is not None:
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "relu_1":
|
||||
node.set_arg_by_node(0, new_conv_node)
|
||||
break
|
||||
|
||||
|
||||
def insert_node_to_subtree(stree):
|
||||
"""在子网络中插入节点"""
|
||||
def _insert_conv(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.Conv2d:
|
||||
position = stree.after(node)
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('x_1')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
# 在名称为'simnet'的子网络中插入新节点
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.Tree and node.get_name() == "simnet":
|
||||
_insert_conv(TreeNodeHelper.get_sub_tree(node))
|
||||
break
|
||||
|
||||
|
||||
def delete_node(stree):
|
||||
"""删除类型为nn.Flatten的节点"""
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.Flatten:
|
||||
for n in node.get_users():
|
||||
n.set_arg(0, "x_7")
|
||||
stree.erase_node(node)
|
||||
break
|
||||
|
||||
|
||||
def replace_node(stree):
|
||||
"""替换网络中的节点"""
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, [ScopedValue.create_naming_value("replace_conv")],
|
||||
args=[ScopedValue.create_naming_value('x')])
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv1":
|
||||
new_conv_node = stree.replace(node, [new_conv_node])
|
||||
|
||||
|
||||
def pattern_replace(stree):
|
||||
"""通过模式匹配的方式替换节点"""
|
||||
class ConvReplacement(Replacement):
|
||||
"""创建新节点类的实现"""
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == nn.MaxPool2d
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
assert bn_node is not None
|
||||
|
||||
conv = nn.Conv2d(1, 1, 1)
|
||||
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs(),
|
||||
name="pattern_conv")
|
||||
return [conv_node]
|
||||
|
||||
class BnReplace(PatternEngine):
|
||||
# 替换网络中nn.MaxPool2d类型的节点
|
||||
def __init__(self):
|
||||
super().__init__([nn.MaxPool2d], ConvReplacement())
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
|
||||
def get_net(stree):
|
||||
"""获取修改后的网络"""
|
||||
return stree.get_network()
|
||||
|
||||
|
||||
def get_code(stree):
|
||||
"""获取修改后的网络代码"""
|
||||
return stree.get_code()
|
||||
|
||||
|
||||
def test_rewrite():
|
||||
"""ReWrite测试函数"""
|
||||
net = Net()
|
||||
stree = create_stree(net)
|
||||
|
||||
print(f"origin code: {stree.get_code()}")
|
||||
insert_node(stree)
|
||||
print(f"after inser node code: {stree.get_code()}")
|
||||
|
||||
insert_node_to_subtree(stree)
|
||||
print(f"after inser node to subtree code: {stree.get_code()}")
|
||||
|
||||
delete_node(stree)
|
||||
print(f"after remove node code: {stree.get_code()}")
|
||||
|
||||
replace_node(stree)
|
||||
print(f"after replace node code: {stree.get_code()}")
|
||||
|
||||
pattern_replace(stree)
|
||||
print(f"after pattern replace node code: {stree.get_code()}")
|
||||
|
||||
inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) # pylint: disable=E1102
|
||||
new_net = get_net(stree)
|
||||
source_code = get_code(stree)
|
||||
print(source_code)
|
||||
out = new_net(inputs)
|
||||
print("out: ", out)
|
||||
export(new_net, inputs, file_name="new_net", file_format="MINDIR")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rewrite()
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
This example mainly illustrates the usage of rewrite.
|
||||
"""
|
||||
from typing import OrderedDict
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor, export
|
||||
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
|
||||
TreeNodeHelper
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class SubNet(nn.Cell):
|
||||
"""Subnetwork definition"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
|
||||
self.mean = ops.ReduceMean(keep_dims=False)
|
||||
self.conv1 = nn.Conv2d(1, 1, 1, stride=1)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Network definition"""
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
self.simnet = SubNet()
|
||||
|
||||
def construct(self, x):
|
||||
"""The forward computing process of networks."""
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.simnet(x)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_stree(network):
|
||||
"""Create SymbolTree"""
|
||||
stree = SymbolTree.create(network)
|
||||
stree.dump()
|
||||
return stree
|
||||
|
||||
|
||||
def insert_node(stree):
|
||||
"""Insert a node into the network"""
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv2": # Insert a new node before the node named 'conv2'
|
||||
position = stree.before(node)
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=node.get_args())
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
# Update the input of an existing node with a new node
|
||||
if new_conv_node is not None:
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "relu_1":
|
||||
node.set_arg_by_node(0, new_conv_node)
|
||||
break
|
||||
|
||||
|
||||
def insert_node_to_subtree(stree):
|
||||
"""Inserting a node into a subnetwork"""
|
||||
def _insert_conv(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.Conv2d:
|
||||
position = stree.after(node)
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('x_1')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
# Insert a new node in the subnet named 'simnet'
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.Tree and node.get_name() == "simnet":
|
||||
_insert_conv(TreeNodeHelper.get_sub_tree(node))
|
||||
break
|
||||
|
||||
|
||||
def delete_node(stree):
|
||||
"""Delete nodes of type nn.Flatten"""
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.Flatten:
|
||||
for n in node.get_users():
|
||||
n.set_arg(0, "x_7")
|
||||
stree.erase_node(node)
|
||||
break
|
||||
|
||||
|
||||
def replace_node(stree):
|
||||
"""Replace nodes in the network"""
|
||||
new_conv = nn.Conv2d(1, 1, 1)
|
||||
new_conv_node = Node.create_call_cell(new_conv, [ScopedValue.create_naming_value("replace_conv")],
|
||||
args=[ScopedValue.create_naming_value('x')])
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv1":
|
||||
new_conv_node = stree.replace(node, [new_conv_node])
|
||||
|
||||
|
||||
def pattern_replace(stree):
|
||||
"""Replace nodes by pattern matching"""
|
||||
class ConvReplacement(Replacement):
|
||||
"""Create the implementation of a new node class."""
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == nn.MaxPool2d
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
assert bn_node is not None
|
||||
|
||||
conv = nn.Conv2d(1, 1, 1)
|
||||
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs(),
|
||||
name="pattern_conv")
|
||||
return [conv_node]
|
||||
|
||||
class BnReplace(PatternEngine):
|
||||
# Replace node of type nn.MaxPool2d in the network
|
||||
def __init__(self):
|
||||
super().__init__([nn.MaxPool2d], ConvReplacement())
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
|
||||
def get_net(stree):
|
||||
"""Get the modified network"""
|
||||
return stree.get_network()
|
||||
|
||||
|
||||
def get_code(stree):
|
||||
"""Get the modified network code"""
|
||||
return stree.get_code()
|
||||
|
||||
|
||||
def test_rewrite():
|
||||
"""ReWrite test function."""
|
||||
net = Net()
|
||||
stree = create_stree(net)
|
||||
|
||||
print(f"origin code: {stree.get_code()}")
|
||||
insert_node(stree)
|
||||
print(f"after inser node code: {stree.get_code()}")
|
||||
|
||||
insert_node_to_subtree(stree)
|
||||
print(f"after inser node to subtree code: {stree.get_code()}")
|
||||
|
||||
delete_node(stree)
|
||||
print(f"after remove node code: {stree.get_code()}")
|
||||
|
||||
replace_node(stree)
|
||||
print(f"after replace node code: {stree.get_code()}")
|
||||
|
||||
pattern_replace(stree)
|
||||
print(f"after pattern replace node code: {stree.get_code()}")
|
||||
|
||||
inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
|
||||
new_net = get_net(stree)
|
||||
source_code = get_code(stree)
|
||||
print(source_code)
|
||||
out = new_net(inputs)
|
||||
print("out: ", out)
|
||||
export(new_net, inputs, file_name="new_net", file_format="MINDIR")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rewrite()
|
|
@ -99,12 +99,6 @@ class Node:
|
|||
args, kwargs, name, is_sub_net))
|
||||
|
||||
def get_handler(self) -> NodeImpl:
|
||||
"""
|
||||
Get handler of node implementation.
|
||||
|
||||
Returns:
|
||||
An instance of `NodeImpl`.
|
||||
"""
|
||||
return self._node
|
||||
|
||||
def get_inputs(self) -> ['Node']:
|
||||
|
@ -209,27 +203,6 @@ class Node:
|
|||
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
|
||||
|
||||
def get_targets(self) -> [ScopedValue]:
|
||||
"""
|
||||
Get targets of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
|
||||
represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
|
||||
ast.Assign.
|
||||
- When node_type of current node is Input, `targets` should have only one element which is a string represents
|
||||
parameter of function.
|
||||
- When node_type of current node is `Python` or `Output`, `targets` are don't-care.
|
||||
|
||||
Returns:
|
||||
A list of instances of ScopedValue as targets of node.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> targets = node.get_targets()
|
||||
"""
|
||||
return self._node.get_targets()
|
||||
|
||||
def get_name(self) -> str:
|
||||
|
@ -284,106 +257,21 @@ class Node:
|
|||
return self._node.get_instance_type()
|
||||
|
||||
def get_instance(self):
|
||||
"""
|
||||
Get the instance of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, instance is an instance of Cell.
|
||||
- When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
|
||||
- When node_type of current node is `Tree`, instance is an instance of network-cell.
|
||||
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
|
||||
|
||||
Returns:
|
||||
A object represents corresponding instance of current node.
|
||||
"""
|
||||
return self._node.get_instance()
|
||||
|
||||
def get_args(self) -> [ScopedValue]:
|
||||
"""
|
||||
Get the arguments of current node.
|
||||
|
||||
- When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
|
||||
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
||||
- When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
|
||||
- When `node_type` of current node is `Output`, arguments represents the return values of network.
|
||||
- When `node_type` of current node is `Python`, arguments are don't-care.
|
||||
|
||||
Returns:
|
||||
A list of instances of `ScopedValue`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> args = node.get_args()
|
||||
"""
|
||||
return self._node.get_args()
|
||||
|
||||
def get_kwargs(self) -> {str: ScopedValue}:
|
||||
"""
|
||||
Get the keyword arguments of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
|
||||
to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
||||
- When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
|
||||
|
||||
Returns:
|
||||
A dict of str to instance of `ScopedValue`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> kwargs = node.get_kwargs()
|
||||
"""
|
||||
return self._node.get_kwargs()
|
||||
|
||||
def set_attribute(self, key: str, value):
|
||||
"""
|
||||
Set attribute of current node.
|
||||
|
||||
Args:
|
||||
key (str): Key of attribute.
|
||||
value (object): Value of attribute.
|
||||
|
||||
Raises:
|
||||
TypeError: If `key` is not a `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> node.set_attribute("channel", 3)
|
||||
"""
|
||||
Validator.check_value_type("key", key, [str], "Node attribute")
|
||||
self._node.set_attribute(key, value)
|
||||
|
||||
def get_attributes(self) -> {str: object}:
|
||||
"""
|
||||
Get all attributes of current node.
|
||||
|
||||
Returns:
|
||||
A dict of str to instance of object as attributes.
|
||||
"""
|
||||
return self._node.get_attributes()
|
||||
|
||||
def get_attribute(self, key: str):
|
||||
"""
|
||||
Get attribute of current node by key.
|
||||
|
||||
Args:
|
||||
key (str): Key of attribute.
|
||||
|
||||
Returns:
|
||||
A object as attribute, can be any type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `key` is not a `str`.
|
||||
"""
|
||||
Validator.check_value_type("key", key, [str], "Node attribute")
|
||||
return self._node.get_attribute(key)
|
||||
|
|
|
@ -74,40 +74,7 @@ class SymbolTree:
|
|||
if v not in MsDtypes and not isinstance(v, ParamTypes):
|
||||
raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
|
||||
|
||||
def create_call_function(self, func, targets, *args, **kwargs):
|
||||
r"""
|
||||
Create a Node object and generate the execution code to insert into the source code.
|
||||
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
||||
|
||||
Args:
|
||||
func (FunctionType): The function to be called.
|
||||
targets (list[str]): indicates the output name. As the output of the node in the source code.
|
||||
args (Union[MsDtypes, ParamTypes]): parameter name of the node. Used as a parameter to a code statement in
|
||||
source code. The default value is None, which means there is no parameter input in the cell.
|
||||
kwargs (dict{str,Union[MsDtypes, ParamTypes]}): The key type must be str,
|
||||
and the value must be value or type must be ParamTypes.
|
||||
The input parameter name used to describe the formal parameter with a keyword.
|
||||
Enter the name in the source code as the 'kwargs' in the statement expression.The default value is
|
||||
None, which means there is no 'kwargs' input.
|
||||
|
||||
Returns:
|
||||
An instance of `Node`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `func` is not FunctionType.
|
||||
TypeError: If `targets` is not `list`.
|
||||
TypeError: If the type of `targets` is not str.
|
||||
TypeError: If arg in `args` is not ParamType.
|
||||
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
|
||||
"""
|
||||
def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
|
||||
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
|
||||
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
|
||||
args_ = list(args)
|
||||
|
@ -122,19 +89,6 @@ class SymbolTree:
|
|||
return Node(self._symbol_tree.create_call_function(func, targets, args_, kwargs))
|
||||
|
||||
def get_handler(self) -> SymbolTreeImpl:
|
||||
"""
|
||||
Get handler of `SymbolTree` implementation.
|
||||
|
||||
Returns:
|
||||
An instance of `SymbolTree`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> handler = stree.get_handler()
|
||||
"""
|
||||
return self._symbol_tree
|
||||
|
||||
def nodes(self):
|
||||
|
@ -156,25 +110,6 @@ class SymbolTree:
|
|||
yield Node(node)
|
||||
|
||||
def get_node(self, node_name: str) -> Optional[Node]:
|
||||
"""
|
||||
Get node by `node_name`.
|
||||
|
||||
Args:
|
||||
node_name (str): A string represents name of node.
|
||||
|
||||
Returns:
|
||||
An instance of node if find else None.
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_name` is not `str`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> node = stree.get_node("conv1")
|
||||
"""
|
||||
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
|
||||
node_impl = self._symbol_tree.get_node(node_name)
|
||||
if node_impl is None:
|
||||
|
@ -358,16 +293,6 @@ class SymbolTree:
|
|||
self._symbol_tree.dump()
|
||||
|
||||
def print_node_tabulate(self):
|
||||
"""
|
||||
Print node information of graph.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.rewrite import SymbolTree
|
||||
>>> from lenet import Lenet
|
||||
>>> net = Lenet()
|
||||
>>> stree = SymbolTree.create(net)
|
||||
>>> stree.print_node_tabulate()
|
||||
"""
|
||||
self._symbol_tree.print_node_tabulate()
|
||||
|
||||
def get_code(self) -> str:
|
||||
|
|
Loading…
Reference in New Issue