del rewrite api

This commit is contained in:
yuzhenhua 2022-12-05 15:03:40 +08:00
parent f292133c39
commit ca0df4f001
5 changed files with 391 additions and 315 deletions

View File

@ -54,26 +54,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
异常:
- **TypeError** - 参数 `network` 不是Cell类型对象。
.. py:method:: mindspore.rewrite.SymbolTree.create_call_function(func, targets, args, kwargs)
创建一个Node对象并生成执行代码插入源码中。源码中以 `args``kwargs` 为参数调用 `func` 函数。
参数:
- **func** (FunctionType) - 要被调用的函数。
- **targets** (list[str]) - 表示输出名称。在源代码中作为节点的输出。
- **args** (Union[MsDtypes, ParamTypes]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 `cell` 没有参数输入。
- **kwargs** (dict{str,Union[MsDtypes, ParamTypes]}) - 键的类型必须是str值必须是MsDtypes或类型必须是ParamTypes。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。默认为None表示没有 `kwargs` 输入。
返回:
一个Node实例。
异常:
- **TypeError** - 如果参数 `func` 不是FunctionType类型。
- **TypeError** - 如果参数 `targets` 不是list类型。
- **TypeError** - 如果参数 `targets` 的成员不是str类型。
- **TypeError** - 如果参数 `args` 不是ParamType类型。
- **TypeError** - 如果参数 `kwarg``key` 不是str类型或者 `value` 不是ParamType类型。
.. py:method:: mindspore.rewrite.SymbolTree.dump()
`SymbolTree` 中network对应的ir图信息打印到屏幕。
@ -98,13 +78,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
strSymbolTree对应的源码字符串。
.. py:method:: mindspore.rewrite.SymbolTree.get_handler()
获取SymbolTree对应实现的handle。
返回:
SymbolTree对象。
.. py:method:: mindspore.rewrite.SymbolTree.get_network()
获取SymbolTree所对应的生成的网络对象。源码会保存到文件中默认的文件名为 `network_define.py`
@ -112,19 +85,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
根据SymbolTree生成的网络对象。
.. py:method:: mindspore.rewrite.SymbolTree.get_node(node_name: str)
获取节点名为 `node_name` 的节点。
参数:
- **node_name** (str) - 节点的名称。
返回:
如果找到则返回结果,否则返回 `None`
异常:
- **TypeError** - 如果 `node_name` 不是Node类型。
.. py:method:: mindspore.rewrite.SymbolTree.insert(position, node: Node)
在SymbolTree的 `position` 位置插入一个节点。 `position` 可以通过 `before``after` 来获得。
@ -148,10 +108,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
当前SymbolTree中节点的生成器。
.. py:method:: mindspore.rewrite.SymbolTree.print_node_tabulate()
打印当前SymbolTree的节点信息表格。
.. py:method:: mindspore.rewrite.SymbolTree.replace(old_node: Node, new_nodes: [Node])
使用新节点列表来替代旧节点。
@ -209,45 +165,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
- **TypeError** - 如果参数 `args` 不是ScopedValue类型。
- **TypeError** - 如果参数 `kwarg``key` 不是str类型或者 `value` 不是ScopedValue类型。
.. py:method:: mindspore.rewrite.Node.get_args()
获取当前节点的参数。
- 当前节点的 `node_type``CallCell``CallPrimitive``Tree` 时,返回值对应于 ast.Call 的 `args`,表示调用 `cell-op``primitive-op``forward` 方法的参数。
- 当前节点的 `node_type``Input` 时,返回值为函数参数的默认值。
- 当前节点的 `node_type``Output` 时,返回值为网络的返回值。
- 当前节点的 `node_type``Python` 时,没有实际含义,可以忽略。
返回:
`ScopedValue` 实例的列表。
.. py:method:: mindspore.rewrite.Node.get_attribute(key: str)
获取当前节点属性 `key` 的值。
参数:
- **key** (str) - 属性的名称。
返回:
属性值,可能是任意类型。
异常:
- **TypeError** - 如果参数 `key` 不是str类型。
.. py:method:: mindspore.rewrite.Node.get_attributes()
获取当前节点的所有属性。
返回:
返回一个包含属性名和属性值的字典。
.. py:method:: mindspore.rewrite.Node.get_handler()
获取节点具体实现的handle。
返回:
返回NodeImpl的实例。
.. py:method:: mindspore.rewrite.Node.get_inputs()
获取当前节点的拓扑序的输入节点。
@ -255,18 +172,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
Node的实例列表。
.. py:method:: mindspore.rewrite.Node.get_instance()
获取当前节点对应的 `operation` 实例。
- 如果当前节点的 `node_type``CallCell`该节点的实例是一个Cell的对象。
- 如果当前节点的 `node_type``CallPrimitive`该节点的实例是一个Primitive的对象。
- 如果当前节点的 `node_type``Tree`,该节点的实例是一个网络的对象。
- 如果当前节点的 `node_type``Python``Input``Output``CallMethod`该节点的实例为None。
返回:
当前节点的 `operation` 实例。
.. py:method:: mindspore.rewrite.Node.get_instance_type()
获取当前节点对应的 `operation` 实例类型。
@ -279,16 +184,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
当前节点的 `operation` 类型。
.. py:method:: mindspore.rewrite.Node.get_kwargs()
获取当前节点带 `key` 值的参数。
- 当前节点的 `node_type``CallCell``CallPrimitive``Tree` 时,关键字参数对应于 `ast.Call``kwargs`,表示调用 `cell-op``Primitive-op` 方法的参数。
- 当前节点的 `node_type``Python``Input``Output` 时,不关心关键字参数。
返回:
`key` 为str `value` 为ScopedValue的字典。
.. py:method:: mindspore.rewrite.Node.get_name()
获取当前节点的名称。当节点被插入到SymbolTree时节点的名称在SymbolTree中应该是唯一的。
@ -303,17 +198,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
返回:
NodeType当前节点的类型。
.. py:method:: mindspore.rewrite.Node.get_targets()
获取当前节点的输出名称。
- 当前节点的 `node_type``CallCell``CallPrimitive``CallMethod``Tree` 时, `target` 为字符串,表示单元操作或原始操作或函数调用的调用结果,它们对应于 `ast.Assign``targets`
- 当前节点的 `node_type``Input` 时, `targets` 应该只有一个元素,字符串代表函数的参数。
- 当前节点的 `node_type``Python``Output` 时, `target` 不需要关心。
返回:
节点输出的ScopedValue列表。
.. py:method:: mindspore.rewrite.Node.get_users()
按拓扑顺序获取当前节点的输出节点。
@ -352,17 +236,6 @@ MindSpore的ReWrite功能用于修改网络前向的计算过程对网络进
- **ValueError** - 如果参数 `out_idx` 超出了 `src_node` 的输出数量。
- **ValueError** - 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,参数 `src_node` 有多个输出。
.. py:method:: mindspore.rewrite.Node.set_attribute(key: str, value)
设置当前节点的属性。
参数:
- **key** (str) - 属性的名称。
- **value** (object) - 属性值。
异常:
- **TypeError** - 如果参数 `key` 不是str类型。
.. py:class:: mindspore.rewrite.NodeType
NodeType表示Node的类型。

View File

@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This example mainly illustrates the usage of rewrite.
"""
from typing import OrderedDict
import numpy as np
import mindspore
from mindspore import Tensor, export
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
TreeNodeHelper
import mindspore.nn as nn
import mindspore.ops as ops
class SubNet(nn.Cell):
"""子网络定义"""
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
self.mean = ops.ReduceMean(keep_dims=False)
self.conv1 = nn.Conv2d(1, 1, 1, stride=1)
def construct(self, x):
x = self.conv1(x)
x = self.dense(x)
return x
class Net(nn.Cell):
"""网络定义"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.simnet = SubNet()
def construct(self, x):
"""网络的前向计算过程"""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.simnet(x)
x = self.flatten(x)
return x
def create_stree(network):
"""创建SymbolTree"""
stree = SymbolTree.create(network)
stree.dump()
return stree
def insert_node(stree):
"""在网络中插入节点"""
for node in stree.nodes():
if node.get_name() == "conv2": # 在名称为'conv2'的节点前面插入新的节点
position = stree.before(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=node.get_args())
stree.insert(position, new_conv_node)
break
# 使用新节点更新已有节点的输入
if new_conv_node is not None:
for node in stree.nodes():
if node.get_name() == "relu_1":
node.set_arg_by_node(0, new_conv_node)
break
def insert_node_to_subtree(stree):
"""在子网络中插入节点"""
def _insert_conv(stree: SymbolTree):
for node in stree.nodes():
if node.get_instance_type() == nn.Conv2d:
position = stree.after(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('x_1')])
stree.insert(position, new_conv_node)
break
# 在名称为'simnet'的子网络中插入新节点
for node in stree.nodes():
if node.get_node_type() == NodeType.Tree and node.get_name() == "simnet":
_insert_conv(TreeNodeHelper.get_sub_tree(node))
break
def delete_node(stree):
"""删除类型为nn.Flatten的节点"""
for node in stree.nodes():
if node.get_instance_type() == nn.Flatten:
for n in node.get_users():
n.set_arg(0, "x_7")
stree.erase_node(node)
break
def replace_node(stree):
"""替换网络中的节点"""
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, [ScopedValue.create_naming_value("replace_conv")],
args=[ScopedValue.create_naming_value('x')])
for node in stree.nodes():
if node.get_name() == "conv1":
new_conv_node = stree.replace(node, [new_conv_node])
def pattern_replace(stree):
"""通过模式匹配的方式替换节点"""
class ConvReplacement(Replacement):
"""创建新节点类的实现"""
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert is_chain_pattern
assert pattern.type() == nn.MaxPool2d
bn_node: Node = matched.get(pattern.name())
assert bn_node is not None
conv = nn.Conv2d(1, 1, 1)
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs(),
name="pattern_conv")
return [conv_node]
class BnReplace(PatternEngine):
# 替换网络中nn.MaxPool2d类型的节点
def __init__(self):
super().__init__([nn.MaxPool2d], ConvReplacement())
bn_replace = BnReplace()
bn_replace.apply(stree)
def get_net(stree):
"""获取修改后的网络"""
return stree.get_network()
def get_code(stree):
"""获取修改后的网络代码"""
return stree.get_code()
def test_rewrite():
"""ReWrite测试函数"""
net = Net()
stree = create_stree(net)
print(f"origin code: {stree.get_code()}")
insert_node(stree)
print(f"after inser node code: {stree.get_code()}")
insert_node_to_subtree(stree)
print(f"after inser node to subtree code: {stree.get_code()}")
delete_node(stree)
print(f"after remove node code: {stree.get_code()}")
replace_node(stree)
print(f"after replace node code: {stree.get_code()}")
pattern_replace(stree)
print(f"after pattern replace node code: {stree.get_code()}")
inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) # pylint: disable=E1102
new_net = get_net(stree)
source_code = get_code(stree)
print(source_code)
out = new_net(inputs)
print("out: ", out)
export(new_net, inputs, file_name="new_net", file_format="MINDIR")
if __name__ == "__main__":
test_rewrite()

View File

@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This example mainly illustrates the usage of rewrite.
"""
from typing import OrderedDict
import numpy as np
import mindspore
from mindspore import Tensor, export
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
TreeNodeHelper
import mindspore.nn as nn
import mindspore.ops as ops
class SubNet(nn.Cell):
"""Subnetwork definition"""
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
self.mean = ops.ReduceMean(keep_dims=False)
self.conv1 = nn.Conv2d(1, 1, 1, stride=1)
def construct(self, x):
x = self.conv1(x)
x = self.dense(x)
return x
class Net(nn.Cell):
"""Network definition"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.simnet = SubNet()
def construct(self, x):
"""The forward computing process of networks."""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.simnet(x)
x = self.flatten(x)
return x
def create_stree(network):
"""Create SymbolTree"""
stree = SymbolTree.create(network)
stree.dump()
return stree
def insert_node(stree):
"""Insert a node into the network"""
for node in stree.nodes():
if node.get_name() == "conv2": # Insert a new node before the node named 'conv2'
position = stree.before(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=node.get_args())
stree.insert(position, new_conv_node)
break
# Update the input of an existing node with a new node
if new_conv_node is not None:
for node in stree.nodes():
if node.get_name() == "relu_1":
node.set_arg_by_node(0, new_conv_node)
break
def insert_node_to_subtree(stree):
"""Inserting a node into a subnetwork"""
def _insert_conv(stree: SymbolTree):
for node in stree.nodes():
if node.get_instance_type() == nn.Conv2d:
position = stree.after(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('x_1')])
stree.insert(position, new_conv_node)
break
# Insert a new node in the subnet named 'simnet'
for node in stree.nodes():
if node.get_node_type() == NodeType.Tree and node.get_name() == "simnet":
_insert_conv(TreeNodeHelper.get_sub_tree(node))
break
def delete_node(stree):
"""Delete nodes of type nn.Flatten"""
for node in stree.nodes():
if node.get_instance_type() == nn.Flatten:
for n in node.get_users():
n.set_arg(0, "x_7")
stree.erase_node(node)
break
def replace_node(stree):
"""Replace nodes in the network"""
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, [ScopedValue.create_naming_value("replace_conv")],
args=[ScopedValue.create_naming_value('x')])
for node in stree.nodes():
if node.get_name() == "conv1":
new_conv_node = stree.replace(node, [new_conv_node])
def pattern_replace(stree):
"""Replace nodes by pattern matching"""
class ConvReplacement(Replacement):
"""Create the implementation of a new node class."""
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert is_chain_pattern
assert pattern.type() == nn.MaxPool2d
bn_node: Node = matched.get(pattern.name())
assert bn_node is not None
conv = nn.Conv2d(1, 1, 1)
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs(),
name="pattern_conv")
return [conv_node]
class BnReplace(PatternEngine):
# Replace node of type nn.MaxPool2d in the network
def __init__(self):
super().__init__([nn.MaxPool2d], ConvReplacement())
bn_replace = BnReplace()
bn_replace.apply(stree)
def get_net(stree):
"""Get the modified network"""
return stree.get_network()
def get_code(stree):
"""Get the modified network code"""
return stree.get_code()
def test_rewrite():
"""ReWrite test function."""
net = Net()
stree = create_stree(net)
print(f"origin code: {stree.get_code()}")
insert_node(stree)
print(f"after inser node code: {stree.get_code()}")
insert_node_to_subtree(stree)
print(f"after inser node to subtree code: {stree.get_code()}")
delete_node(stree)
print(f"after remove node code: {stree.get_code()}")
replace_node(stree)
print(f"after replace node code: {stree.get_code()}")
pattern_replace(stree)
print(f"after pattern replace node code: {stree.get_code()}")
inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
new_net = get_net(stree)
source_code = get_code(stree)
print(source_code)
out = new_net(inputs)
print("out: ", out)
export(new_net, inputs, file_name="new_net", file_format="MINDIR")
if __name__ == "__main__":
test_rewrite()

View File

@ -99,12 +99,6 @@ class Node:
args, kwargs, name, is_sub_net))
def get_handler(self) -> NodeImpl:
"""
Get handler of node implementation.
Returns:
An instance of `NodeImpl`.
"""
return self._node
def get_inputs(self) -> ['Node']:
@ -209,27 +203,6 @@ class Node:
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
def get_targets(self) -> [ScopedValue]:
"""
Get targets of current node.
- When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
ast.Assign.
- When node_type of current node is Input, `targets` should have only one element which is a string represents
parameter of function.
- When node_type of current node is `Python` or `Output`, `targets` are don't-care.
Returns:
A list of instances of ScopedValue as targets of node.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> targets = node.get_targets()
"""
return self._node.get_targets()
def get_name(self) -> str:
@ -284,106 +257,21 @@ class Node:
return self._node.get_instance_type()
def get_instance(self):
"""
Get the instance of current node.
- When node_type of current node is `CallCell`, instance is an instance of Cell.
- When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
- When node_type of current node is `Tree`, instance is an instance of network-cell.
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
Returns:
A object represents corresponding instance of current node.
"""
return self._node.get_instance()
def get_args(self) -> [ScopedValue]:
"""
Get the arguments of current node.
- When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
- When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
- When `node_type` of current node is `Output`, arguments represents the return values of network.
- When `node_type` of current node is `Python`, arguments are don't-care.
Returns:
A list of instances of `ScopedValue`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> args = node.get_args()
"""
return self._node.get_args()
def get_kwargs(self) -> {str: ScopedValue}:
"""
Get the keyword arguments of current node.
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
- When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
Returns:
A dict of str to instance of `ScopedValue`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> kwargs = node.get_kwargs()
"""
return self._node.get_kwargs()
def set_attribute(self, key: str, value):
"""
Set attribute of current node.
Args:
key (str): Key of attribute.
value (object): Value of attribute.
Raises:
TypeError: If `key` is not a `str`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node.set_attribute("channel", 3)
"""
Validator.check_value_type("key", key, [str], "Node attribute")
self._node.set_attribute(key, value)
def get_attributes(self) -> {str: object}:
"""
Get all attributes of current node.
Returns:
A dict of str to instance of object as attributes.
"""
return self._node.get_attributes()
def get_attribute(self, key: str):
"""
Get attribute of current node by key.
Args:
key (str): Key of attribute.
Returns:
A object as attribute, can be any type.
Raises:
TypeError: If `key` is not a `str`.
"""
Validator.check_value_type("key", key, [str], "Node attribute")
return self._node.get_attribute(key)

View File

@ -74,40 +74,7 @@ class SymbolTree:
if v not in MsDtypes and not isinstance(v, ParamTypes):
raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
def create_call_function(self, func, targets, *args, **kwargs):
r"""
Create a Node object and generate the execution code to insert into the source code.
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
Args:
func (FunctionType): The function to be called.
targets (list[str]): indicates the output name. As the output of the node in the source code.
args (Union[MsDtypes, ParamTypes]): parameter name of the node. Used as a parameter to a code statement in
source code. The default value is None, which means there is no parameter input in the cell.
kwargs (dict{str,Union[MsDtypes, ParamTypes]}): The key type must be str,
and the value must be value or type must be ParamTypes.
The input parameter name used to describe the formal parameter with a keyword.
Enter the name in the source code as the 'kwargs' in the statement expression.The default value is
None, which means there is no 'kwargs' input.
Returns:
An instance of `Node`.
Raises:
TypeError: If `func` is not FunctionType.
TypeError: If `targets` is not `list`.
TypeError: If the type of `targets` is not str.
TypeError: If arg in `args` is not ParamType.
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = stree.create_call_function(F.abs, ["x"], node)
"""
def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
args_ = list(args)
@ -122,19 +89,6 @@ class SymbolTree:
return Node(self._symbol_tree.create_call_function(func, targets, args_, kwargs))
def get_handler(self) -> SymbolTreeImpl:
"""
Get handler of `SymbolTree` implementation.
Returns:
An instance of `SymbolTree`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> handler = stree.get_handler()
"""
return self._symbol_tree
def nodes(self):
@ -156,25 +110,6 @@ class SymbolTree:
yield Node(node)
def get_node(self, node_name: str) -> Optional[Node]:
"""
Get node by `node_name`.
Args:
node_name (str): A string represents name of node.
Returns:
An instance of node if find else None.
Raises:
TypeError: If `node_name` is not `str`.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
"""
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
node_impl = self._symbol_tree.get_node(node_name)
if node_impl is None:
@ -358,16 +293,6 @@ class SymbolTree:
self._symbol_tree.dump()
def print_node_tabulate(self):
"""
Print node information of graph.
Examples:
>>> from mindspore.rewrite import SymbolTree
>>> from lenet import Lenet
>>> net = Lenet()
>>> stree = SymbolTree.create(net)
>>> stree.print_node_tabulate()
"""
self._symbol_tree.print_node_tabulate()
def get_code(self) -> str: