diff --git a/docs/api/api_python/mindspore.rewrite.rst b/docs/api/api_python/mindspore.rewrite.rst index 0f62b6c6328..a085a388f72 100644 --- a/docs/api/api_python/mindspore.rewrite.rst +++ b/docs/api/api_python/mindspore.rewrite.rst @@ -4,48 +4,49 @@ mindspore.rewrite .. py:class:: mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl) - SymbolTree通常对应于网络的forward方法。 + SymbolTree通常对应于网络的前向计算过程。 参数: - - **network** (Cell) - 要重写的网络。现在只支持Cell类型的网络。 + - **handler** (SymbolTreeImpl) - SymbolTree内部实现实例。 异常: - - **RuntimeError** - 如果 `network` 不是Cell对象。 - - **RuntimeError** - 如果 `network` 中包含不支持解析和优化的ast节点类型。 + - **RuntimeError** - `network` 不是Cell对象。 + - **RuntimeError** - `network` 中包含不支持解析和优化的ast节点类型。 .. py:method:: mindspore.rewrite.SymbolTree.after(node: Node) 获取插入位置,位置为 `node` 之后。 - 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。我们不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数即可。 + 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position是什么,只需将其视为处理程序并将其用作SymbolTree的插入接口的参数。 参数: - - **node** (Node) - 指定插入位置在哪个节点之后。 + - **node** (Node) - 指定插入位置在哪个节点之后,可以是Node或者Node的名称。 返回: Position,指定插入节点的位置。 异常: - - **TypeError** - 如果参数不是Node类型。 + - **TypeError** - 参数不是Node类型。 .. py:method:: mindspore.rewrite.SymbolTree.before(node: Node) 与after的区别是,该接口返回的位置为 `node` 之前。 + 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 `Position` 是什么,只需将其视为处理程序并将其用作 `SymbolTree` 的插入接口的参数。 参数: - - **node** (Node) - 指定插入位置在哪个节点之前。 + - **node** (Node) - 指定插入位置在哪个节点之前,可以是Node或者Node的名称。 返回: Position,指定插入节点的位置。 异常: - - **TypeError** - 如果参数不是Node类型。 + - **TypeError** - 参数不是Node类型。 .. py:method:: mindspore.rewrite.SymbolTree.create(network) - 根据传入的 `network` 创建一个SymbolTree对象。 + 根据传入的 `network` 创建SymbolTree对象。 参数: - - **network** (Cell) - 要重写的网络。现在只支持Cell类型的网络。 + - **network** (Cell) - 重写的网络。 返回: SymbolTree,基于 `network` 创建的符号树。 @@ -82,7 +83,7 @@ mindspore.rewrite 删除SymbolTree中的一个节点。被删除的节点必须不被其他节点依赖。 参数: - - **node** (Node) - 被删除的节点。 + - **node** (Node) - 被删除的节点。可以是Node或者Node的名称。 返回: 如果 `node` 属于当前的SymbolTree则返回被删除节点。否则返回None。 @@ -99,7 +100,7 @@ mindspore.rewrite .. py:method:: mindspore.rewrite.SymbolTree.get_handler() - 获取SymbolTree所对应的实现句柄。 + 获取SymbolTree对应实现的handle。 返回: SymbolTree对象。 @@ -129,7 +130,7 @@ mindspore.rewrite 如果找到则返回结果,否则返回 `None`。 异常: - - **TypeError** - 如果参数不是Node类型。 + - **TypeError** - 如果 `node_name` 不是Node类型。 .. py:method:: mindspore.rewrite.SymbolTree.get_saved_file_name() @@ -217,28 +218,27 @@ mindspore.rewrite 节点是表达网络中源代码的一种数据结构。 - 在大多数情况下,Node表示一个向前计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。 - - 下面提到的NodeImpl是Node的实现,它不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例,例如 `create_call_cell`,而不是直接调用Node的构造函数,所以不要关心NodeImpl是什么,只需要看做一个句柄即可。 + 在大多数情况下,Node表示前向计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。 参数: - - **node** (NodeImpl) - SymbolTree中节点的具体实现类的实例。 + - **node** (NodeImpl) - `NodeImpl` 的handle。NodeImpl是Node的实现,不是Rewrite的接口。Rewrite建议调用Node的特定 `create` 方法来实例化Node的实例,例如 `create_call_cell`,而不直接调用Node的构造函数,不需关心NodeImpl是什么,只需作为handle看待。 .. py:method:: mindspore.rewrite.Node.create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = "", is_sub_net: bool = False) :staticmethod: - 通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式: ``targets = self.name(*args, **kwargs)``。 + 通过该接口可以根据 `cell` 对象创建一个Node实例。节点对应的源代码格式: + ``targets = self.name(*args, **kwargs)``。 参数: - **cell** (Cell) - 该节点对应的前向计算的Cell对象。 - **targets** (list[ScopedValue]) - 表示输出名称。在源代码中作为节点的输出。Rewrite将在插入节点时检查并确保每个目标的唯一性。 - - **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。 - - **kwargs** (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。默认为None,表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。 - - **name** (str) - 表示节点的名称。用作源代码中的字段名称。默认为无。当名称为无时,ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。 - - **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真,Rewrite将尝试将 `cell` 解析为TreeNode,否则为CallCell节点。默认为False。 + - **args** (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。表示 `cell` 没有参数输入。Rewrite将在插入节点时检查并确保每个 `arg` 的唯一性。默认值:None。 + - **kwargs** (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 `kwargs`。表示 `cell` 没有 `kwargs` 输入。Rewrite将在插入节点时检查并确保每个 `kwarg` 的唯一性。默认值:None。 + - **name** (str) - 表示节点的名称。用作源代码中的字段名称。当名称为无时,ReWrite将根据 `target` 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值:None。 + - **is_sub_net** (bool) - 表示 `cell` 是否是一个网络。如果 `is_sub_net` 为真,Rewrite将尝试将 `cell` 解析为TreeNode,否则为CallCell节点。默认值:False。 返回: - 一个Node实例。 + Node实例。 异常: - **TypeError** - 如果参数 `cell` 不是Cell类型。 @@ -253,7 +253,7 @@ mindspore.rewrite - 当前节点的 `node_type` 为 `CallCell`、 `CallPrimitive` 或 `Tree` 时,返回值对应于 ast.Call 的 `args`,表示调用 `cell-op` 或 `primitive-op` 的 `forward` 方法的参数。 - 当前节点的 `node_type` 为 `Input` 时,返回值为函数参数的默认值。 - - 当前节点的 `node_type` 为 `Output` 时,返回值对应网络的返回值。 + - 当前节点的 `node_type` 为 `Output` 时,返回值为网络的返回值。 - 当前节点的 `node_type` 为 `Python` 时,没有实际含义,可以忽略。 返回: @@ -281,14 +281,14 @@ mindspore.rewrite .. py:method:: mindspore.rewrite.Node.get_handler() - 获取节点具体实现的句柄。 + 获取节点具体实现的handle。 返回: 返回NodeImpl的实例。 .. py:method:: mindspore.rewrite.Node.get_inputs() - 获取当前节点的输入节点。 + 获取当前节点的拓扑序的输入节点。 返回: Node的实例列表。 @@ -309,9 +309,9 @@ mindspore.rewrite 获取当前节点对应的 `operation` 实例类型。 - - 如果当前节点的 `node_type` 是 `CallCell`,该节点是一个Cell对象。 - - 如果当前节点的 `node_type` 是 `CallPrimitive`,该节点的是一个Primitive对象。 - - 如果当前节点的 `node_type` 是 `Tree`,该节点的类型是一个网络。 + - 如果当前节点的 `node_type` 是 `CallCell`,该节点是Cell对象。 + - 如果当前节点的 `node_type` 是 `CallPrimitive`,该节点的是Primitive对象。 + - 如果当前节点的 `node_type` 是 `Tree`,该节点的类型是网络。 - 如果当前节点的 `node_type` 是 `Python`、 `Input`、 `Output`、 `CallMethod`,该节点的类型为NoneType。 返回: @@ -332,7 +332,7 @@ mindspore.rewrite 获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。 返回: - str,节点的名称。 + 节点的名称,类型为str。 .. py:method:: mindspore.rewrite.Node.get_next() @@ -385,14 +385,14 @@ mindspore.rewrite - **TypeError** - 如果参数 `index` 不是int类型。 - **TypeError** - 如果参数 `arg` 不是str或者ScopedValue类型。 - .. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None) + .. py:method:: mindspore.rewrite.Node.set_arg_by_node(arg_idx: int, src_node: Node, out_idx: Optional[int] = None) 将另一个节点设置为当前节点的输入。 参数: - **arg_idx** (int) - 要设置的参数索引。 - - **src_node** (Node) - 作为输入的节点。 - - **out_idx** (int,optional) - 指定输入节点的哪个输出作为当前节点输入,默认是None,则取第一个输出。 + - **src_node** (Node) - 输入的节点。 + - **out_idx** (int,optional) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值:None。 异常: - **RuntimeError** - 如果 `src_node` 不属于当前的SymbolTree。 @@ -402,7 +402,7 @@ mindspore.rewrite - **TypeError** - 如果参数 `src_node` 不是Node类型。 - **TypeError** - 如果参数 `out_idx` 不是int类型。 - **ValueError** - 如果参数 `out_idx` 超出了 `src_node` 的输出数量。 - - **ValueError** - 如果参数 `src_node` 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,有多个输出。 + - **ValueError** - 当 `out_idx` 为None或者没有给 `out_idx` 赋值时,参数 `src_node` 有多个输出。 .. py:method:: mindspore.rewrite.Node.set_attribute(key: str, value) @@ -432,21 +432,21 @@ mindspore.rewrite ScopedValue表示具有完整范围的值。 - ScopedValue用于表示:一个左值,如赋值语句的目标,或可调用对象,如调用语句的 `func`,或右值,如赋值语句的 `args` 和 `kwargs`。 + ScopedValue用于表示:左值,如赋值语句的目标,或可调用对象,如调用语句的 `func`,或右值,如赋值语句的 `args` 和 `kwargs`。 参数: - - **arg_type** (ValueType) - 表示当前值的类型。 - - **scope** (str) - 一个字符串表示当前值的范围。以"self.var1"为例,这个var1的作用域是"self"。 - - **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`。 + - **arg_type** (ValueType) - 当前值的类型。 + - **scope** (str) - 字符串表示当前值的范围。以"self.var1"为例,这个var1的作用域是"self"。默认值:""。 + - **value** - 当前ScopedValue中保存的值。值的类型对应于 `arg_type`。默认值:None。 .. py:method:: mindspore.rewrite.ScopedValue.create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None) :staticmethod: - 创建一个ScopedValue的列表。 + 创建ScopedValue的列表。 参数: - - **names** (list[str] or tuple[str]) – str 的列表或元组表示引用变量的名称。 - - **scopes** (list[str] or tuple[str]) – str 的列表或元组表示引用变量的范围,默认值None表示没有指定作用范围。 + - **names** (list[str] or tuple[str]) – 引用变量的名称,类型为str的列表或元组。 + - **scopes** (list[str] or tuple[str]) – 引用变量的范围,类型为str的列表或元组。表示没有指定作用范围。默认值:None。 返回: ScopedValue的实例列表。 @@ -462,7 +462,7 @@ mindspore.rewrite 参数: - **name** (str) – 表示变量的字符串。 - - **scope** (str) – 表示变量范围的字符串,默认值为空字符串,表示没有指定作用范围。 + - **scope** (str) – 表示变量范围的字符串,表示没有指定作用范围。默认值:空字符串。 返回: ScopedValue的实例。 @@ -490,11 +490,11 @@ mindspore.rewrite .. py:class:: mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None) - PatternEngine实现了如何通过PattenNode修改SymbolTree。 + PatternEngine通过PattenNode修改SymbolTree。 参数: - **pattern** (Union[PatternNode, List]) - PatternNode的实例或用于构造 `Pattent` 的Cell类型列表。 - - **replacement** (callable) - 生成新节点的接口实现,如果为None则不进行任何匹配操作。 + - **replacement** (callable) - 生成新节点的接口实现。 .. py:method:: mindspore.rewrite.PatternEngine.apply(stree: SymbolTree) @@ -525,12 +525,12 @@ mindspore.rewrite 参数: - **pattern_node_name** (str) - 节点名称。 - - **match_type** (Type) - 当前节点的匹配类型。 - - **inputs** (list[PatternNode]) - 当前节点的输入节点。 + - **match_type** (Type) - 当前节点的匹配类型。默认值:None。 + - **inputs** (list[PatternNode]) - 当前节点的输入节点。默认值:None。 .. py:method:: mindspore.rewrite.PatternNode.add_input(node) - 为当前节点添加一个输入。 + 为当前节点添加输入。 参数: - **node** (PatternNode) - 新增的输入节点。 @@ -541,10 +541,10 @@ mindspore.rewrite .. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_list(type_list: []) :staticmethod: - 使用一个类型的列表来创建一个Pattern。 + 使用类型的列表来创建Pattern。 参数: - - **type_list** (list[type]) - 类型列表,当前支持Cell和Primitive。 + - **type_list** (list[type]) - 类型列表。 返回: 根据列表生成的模式的根节点。 @@ -555,7 +555,7 @@ mindspore.rewrite .. py:method:: mindspore.rewrite.PatternNode.create_pattern_from_node(node: Node) :staticmethod: - 根据一个节点及其输入创建一个Pattern。 + 根据节点及其输入创建Pattern。 参数: - **node** (Node) - 要修改的节点。 @@ -635,7 +635,7 @@ mindspore.rewrite 参数: - **pattern** (PatternNode) - 当前模式的根节点。 - - **is_chain_pattern** (bool) - 标记模式是链模式或树模式。 + - **is_chain_pattern** (bool) - 标记,标记模式是链模式或树模式。 - **matched** (OrderedDict) - 匹配结果,从名称映射到节点的字典。 返回: @@ -645,7 +645,7 @@ mindspore.rewrite TreeNodeHelper用于在从Tree类型节点获取 `symbol_tree` 时打破循环引用。 - TreeNodeHelper提供了一个静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`。 + TreeNodeHelper提供了静态方法 `get_sub_tree` 用于从Tree类型节点获取 `symbol_tree`。 .. py:method:: mindspore.rewrite.TreeNodeHelper.get_sub_tree(node: Node) :staticmethod: @@ -653,11 +653,11 @@ mindspore.rewrite 获取Tree类型节点的 `symbol_tree`。 参数: - - **node** (Node) - 一个可以持有子符号树的节点。 + - **node** (Node) - 可以持有子符号树的节点。 返回: Tree节点中的SymbolTree对象。注意节点的 `symbol_tree` 可能是None,在这种情况下,方法将返回None。 异常: - - **RuntimeError** - 如果参数 `node` 的 `node_type` 不是Tree类型。 + - **RuntimeError** - 如果参数 `node` 不是 NodeType.Tree类型。 - **TypeError** - 如果参数 `node` 不是Node类型实例。 diff --git a/mindspore/python/mindspore/rewrite/api/node.py b/mindspore/python/mindspore/rewrite/api/node.py index b4163461f1b..687d50f314d 100644 --- a/mindspore/python/mindspore/rewrite/api/node.py +++ b/mindspore/python/mindspore/rewrite/api/node.py @@ -32,13 +32,11 @@ class Node: For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance of `Primitive` or a callable method. - `NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend - invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than - invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a - handler. - Args: - node (NodeImpl): A handler of `NodeImpl`. + node (NodeImpl): A handler of `NodeImpl`. `NodeImpl` mentioned below is implementation of `Node` which is not + an interface of Rewrite. Rewrite recommend invoking specific create method of `Node` to instantiate an instance + of Node such as `create_call_cell` rather than invoking constructor of `Node` directly, so don't care about + what is `NodeImpl` and use its instance just as a handler. """ def __init__(self, node: NodeImpl): @@ -142,6 +140,14 @@ class Node: Returns: A list of nodes represents users. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> users = node.get_users() """ belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree() if belong_symbol_tree is None: @@ -164,6 +170,14 @@ class Node: Raises: TypeError: If `index` is not a `int` number. TypeError: If the type of `arg` is not in [`ScopedValue`, `str`]. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> node.set_arg(0, "x") """ Validator.check_value_type("index", index, [int], "Node") Validator.check_value_type("arg", arg, [ScopedValue, str], "Node") @@ -175,7 +189,7 @@ class Node: def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None): """ - Set argument of current node by another `Node`. + Set argument of current node by another Node. Args: arg_idx (int): Indicate which input being modified. @@ -192,6 +206,15 @@ class Node: TypeError: If `out_idx` is not a `int` number. ValueError: If `out_idx` is out of range. ValueError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> src_node = stree.get_node("conv1") + >>> dst_node = stree.get_node("conv2") + >>> dst_node.set_arg_by_node(0, src_node) """ Validator.check_value_type("arg_idx", arg_idx, [int], "Node") Validator.check_value_type("src_node", src_node, [Node], "Node") @@ -216,6 +239,14 @@ class Node: Returns: A list of instances of ScopedValue as targets of node. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> targets = node.get_targets() """ return self._node.get_targets() @@ -227,6 +258,14 @@ class Node: Returns: A string as name of node. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> name = node.get_name() """ return self._node.get_name() @@ -236,6 +275,14 @@ class Node: Returns: A NodeType as node_type of node. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> node_type = node.get_node_type() """ return self._node.get_node_type() @@ -272,14 +319,22 @@ class Node: """ Get the arguments of current node. - - When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args + - When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op. - - When node_type of current node is `Input`, arguments represents default-value of argument of function. - - When node_type of current node is `Output`, arguments represents return values. - - When node_type of current node is `Python`, arguments are don't-care. + - When `node_type` of current node is `Input`, arguments represents default-value of argument of function. + - When `node_type` of current node is `Output`, arguments represents the return values of network. + - When `node_type` of current node is `Python`, arguments are don't-care. Returns: A list of instances of `ScopedValue`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> args = node.get_args() """ return self._node.get_args() @@ -293,6 +348,14 @@ class Node: Returns: A dict of str to instance of `ScopedValue`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> kwargs = node.get_kwargs() """ return self._node.get_kwargs() @@ -306,6 +369,14 @@ class Node: Raises: TypeError: If `key` is not a `str`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> node.set_attribute("channel", 3) """ Validator.check_value_type("key", key, [str], "Node attribute") self._node.set_attribute(key, value) @@ -327,7 +398,7 @@ class Node: key (str): Key of attribute. Returns: - A object as attribute. + A object as attribute, can be any type. Raises: TypeError: If `key` is not a `str`. diff --git a/mindspore/python/mindspore/rewrite/api/pattern_engine.py b/mindspore/python/mindspore/rewrite/api/pattern_engine.py index c9d7a952c37..5ae95035337 100644 --- a/mindspore/python/mindspore/rewrite/api/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/api/pattern_engine.py @@ -33,8 +33,8 @@ class PatternNode: Args: pattern_node_name (str): Name of current node. - match_type (Type): A type represents what type would be matched of current node. - inputs (list[PatternNode]): Input nodes of current node. + match_type (Type): A type represents what type would be matched of current node. Default: None. + inputs (list[PatternNode]): Input nodes of current node. Default: None. """ def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None): @@ -166,13 +166,13 @@ class PatternNode: def name(self) -> str: """ - Getter of name. + Getter of PatternNode name. """ return self._name def type(self): """ - Getter of type. + Getter of PatternNode type. """ return self._type @@ -192,8 +192,17 @@ class VarNode(PatternNode): class Replacement(abc.ABC): """ Interface of replacement function. - """ + Examples: + >>> from mindspore.rewrite import Replacement, Node + >>> from mindspore.nn import nn + >>> class BnReplacement(Replacement): + ... def build(self, pattern, is_chain_pattern: bool, matched): + ... bn_node: Node = matched.get(pattern.name()) + ... conv = nn.Conv2d(16, 16, 3) + ... conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs()) + ... return [conv_node] + """ @abc.abstractmethod def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: """ diff --git a/mindspore/python/mindspore/rewrite/api/scoped_value.py b/mindspore/python/mindspore/rewrite/api/scoped_value.py index a3ff4ba6415..9b2297cbdd0 100644 --- a/mindspore/python/mindspore/rewrite/api/scoped_value.py +++ b/mindspore/python/mindspore/rewrite/api/scoped_value.py @@ -73,6 +73,10 @@ class ScopedValue: Returns: An instance of `ScopedValue`. + + Examples: + >>> from mindspore.rewrite import ScopedValue + >>> variable = ScopedValue.create_variable_value(2) """ if isinstance(value, int): return cls(ValueType.IntValue, "", value) @@ -108,6 +112,10 @@ class ScopedValue: Raises: TypeError: If `name` is not `str`. TypeError: If `scope` is not `str`. + + Examples: + >>> from mindspore.rewrite import ScopedValue + >>> variable = ScopedValue.create_naming_value("conv", "self") """ Validator.check_value_type("name", name, [str], "ScopedValue") Validator.check_value_type("scope", scope, [str], "ScopedValue") @@ -129,6 +137,10 @@ class ScopedValue: RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None. TypeError: If `names` is not `list` or `tuple` and name in `names` is not `str`. TypeError: If `scopes` is not `list` or `tuple` and scope in `scopes` is not `str`. + + Examples: + >>> from mindspore.rewrite import ScopedValue + >>> variables = ScopedValue.create_name_values(["z", "z_1"]), name="subnet") """ Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue") if scopes is not None: diff --git a/mindspore/python/mindspore/rewrite/api/symbol_tree.py b/mindspore/python/mindspore/rewrite/api/symbol_tree.py index c437aac2507..bc731b35072 100644 --- a/mindspore/python/mindspore/rewrite/api/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite/api/symbol_tree.py @@ -99,6 +99,14 @@ class SymbolTree: TypeError: If the type of `targets` is not str. TypeError: If arg in `args` is not ParamType. TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> new_node = stree.create_call_function(F.abs, ["x"], node) """ Validator.check_value_type("func", func, [FunctionType], "SymbolTree node") Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node") @@ -119,6 +127,13 @@ class SymbolTree: Returns: An instance of `SymbolTree`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> handler = stree.get_handler() """ return self._symbol_tree @@ -128,6 +143,14 @@ class SymbolTree: Returns: A generator for node of current `SymbolTree`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> for node in stree.nodes(): + ... node.set_attribute("channel", 3) """ for node in self._symbol_tree.nodes(): yield Node(node) @@ -144,6 +167,13 @@ class SymbolTree: Raises: TypeError: If `node_name` is not `str`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") """ Validator.check_value_type("node_name", node_name, [str], "SymbolTree") node_impl = self._symbol_tree.get_node(node_name) @@ -157,6 +187,13 @@ class SymbolTree: Returns: [Node], the node list of the current `Symboltree`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> inputs = stree.get_inputs() """ return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()] @@ -176,6 +213,15 @@ class SymbolTree: Raises: TypeError: if `node` is not a `Node`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> for node in stree.nodes(): + ... if node.get_name() == "conv1": + ... position = stree.before(node) """ Validator.check_value_type("node", node, [Node], "SymbolTree") return self._symbol_tree.before(node.get_handler()) @@ -196,6 +242,15 @@ class SymbolTree: Raises: TypeError: If `node` is not a `Node`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> for node in stree.nodes(): + ... if node.get_name() == "conv1": + ... position = stree.after(node) """ Validator.check_value_type("node", node, [Node], "SymbolTree") return self._symbol_tree.after(node.get_handler()) @@ -218,6 +273,16 @@ class SymbolTree: RuntimeError: If `position` is not belong to current `SymbolTree`. TypeError: If `position` is not a `Position`. TypeError: If `node` is not a `Node`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> position = stree.after(node) + >>> new_node = stree.create_call_function(F.abs, ["x"], node) + >>> stree.insert(position, new_node) """ Validator.check_value_type("position", position, [Position], "SymbolTree") Validator.check_value_type("node", node, [Node], "SymbolTree") @@ -235,6 +300,18 @@ class SymbolTree: Raises: TypeError: If `node` is not a `Node`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> input_node = node.get_inputs()[0] + >>> output_nodes = node.get_users() + >>> for n in output_nodes: + ... n.set_arg(0, "x") + >>> stree.erase_node(node) """ Validator.check_value_type("node", node, [Node], "SymbolTree") return Node(self._symbol_tree.erase_node(node.get_handler())) @@ -264,9 +341,18 @@ class SymbolTree: An instance of Node represents root of node_tree been replaced in. Raises: - RuntimeError: Old node is isolated. + RuntimeError: Old node is not isolated. TypeError: If `old_node` is not a `Node`. TypeError: If `new_nodes` is not a `list` or node in `new_nodes` is not a `Node`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> node = stree.get_node("conv1") + >>> new_node = stree.create_call_function(F.abs, ["x"], node) + >>> stree.replace(node, [new_node]) """ Validator.check_value_type("old_node", old_node, [Node], "SymbolTree") Validator.check_element_type_of_iterable("new_nodes", new_nodes, [Node], "SymbolTree") @@ -288,6 +374,13 @@ class SymbolTree: RuntimeError: If `index` is out of range. TypeError: If `index` is not a `int` number. TypeError: If `return_value` is not a `str`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.set_output(0, "x_10") """ Validator.check_value_type("index", index, [int], "SymbolTree") Validator.check_value_type("return_value", return_value, [str], "SymbolTree") @@ -295,13 +388,20 @@ class SymbolTree: def dump(self): """ - Dump graph to console. + Print the ir map information corresponding to the network in 'SymbolTree' to the screen. """ self._symbol_tree.dump() def print_node_tabulate(self): """ Print node information of graph. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.print_node_tabulate() """ self._symbol_tree.print_node_tabulate() @@ -311,6 +411,13 @@ class SymbolTree: Returns: A str represents source code of modified network. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.get_code() """ return self._symbol_tree.get_code() @@ -321,6 +428,13 @@ class SymbolTree: Returns: A network object. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.get_network() """ return self._symbol_tree.get_network() @@ -330,16 +444,40 @@ class SymbolTree: Args: file_name (str): filename to be set. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.set_saved_file_name("new_net") """ Validator.check_value_type("file_name", file_name, [str], "Saving network") self._symbol_tree.set_saved_file_name(file_name) def get_saved_file_name(self): - """Gets the filename used to save the network.""" + """ + Gets the filename used to save the network. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net)x + >>> stree.set_saved_file_name("new_net") + >>> stree.get_saved_file_name() + """ return self._symbol_tree.get_saved_file_name() def save_network_to_file(self): """ Save the modified network to a file. Default file name is `network_define.py`. + + Examples: + >>> from mindspore.rewrite import SymbolTree + >>> from lenet import Lenet + >>> net = Lenet() + >>> stree = SymbolTree.create(net) + >>> stree.save_network_to_file() """ self._symbol_tree.save_network_to_file() diff --git a/mindspore/python/mindspore/rewrite/api/tree_node_helper.py b/mindspore/python/mindspore/rewrite/api/tree_node_helper.py index faf3859d6fe..e6da92d9df0 100644 --- a/mindspore/python/mindspore/rewrite/api/tree_node_helper.py +++ b/mindspore/python/mindspore/rewrite/api/tree_node_helper.py @@ -36,7 +36,7 @@ class TreeNodeHelper: Getting symbol_tree from a `Tree` type `Node`. Args: - node (Node): A `Node` who may hold a sub-symbol_tree. + node (Node): A `Node` which may hold a sub-symbol_tree. Returns: An instance of SymbolTree represents sub-symbol_tree. Note that `node`'s symbol_tree maybe None, in this