forked from mindspore-Ecosystem/mindspore
!32363 add AstFinder, AstReplacer, Observer, Observable and ClassNamer in Rewrite
Merge pull request !32363 from hangq/mscompression-pr-dev
This commit is contained in:
commit
feac726ec7
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
MindSpore Rewrite module.
|
||||
MindSpore Rewrite package.
|
||||
This is an experimental python package that is subject to change or deletion.
|
||||
"""
|
||||
from .parsers.module_parser import g_module_parser
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
`ast_helpers` package of MindSpore Rewrite package.
|
||||
Define some ast helpers for manipulating python ast.
|
||||
"""
|
||||
|
||||
from .ast_finder import AstFinder
|
||||
from .ast_replacer import AstReplacer
|
||||
from .ast_modifier import AstModifier
|
||||
|
||||
__all__ = ["AstFinder", "AstReplacer", "AstModifier"]
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Find specific type ast node in specific scope."""
|
||||
|
||||
from typing import Type
|
||||
import ast
|
||||
|
||||
|
||||
class AstFinder(ast.NodeVisitor):
|
||||
"""
|
||||
Find all specific type ast node in specific scope.
|
||||
|
||||
Args:
|
||||
node (ast.AST): An instance of ast node as search scope.
|
||||
"""
|
||||
|
||||
def __init__(self, node: ast.AST):
|
||||
self._scope: ast.AST = node
|
||||
self._targets: tuple = ()
|
||||
self._results: [ast.AST] = []
|
||||
|
||||
def generic_visit(self, node):
|
||||
"""
|
||||
An override method, iterating over all nodes and save target ast nodes.
|
||||
|
||||
|
||||
Args:
|
||||
node (ast.AST): An instance of ast node which is visited currently.
|
||||
"""
|
||||
|
||||
if isinstance(node, self._targets):
|
||||
self._results.append(node)
|
||||
super(AstFinder, self).generic_visit(node)
|
||||
|
||||
def find_all(self, ast_types) -> [ast.AST]:
|
||||
"""
|
||||
Find all matched ast node.
|
||||
|
||||
Args:
|
||||
ast_types (Union[tuple(Type), Type]): A tuple of Type or a Type indicates target ast node type.
|
||||
|
||||
Returns:
|
||||
A list of instance of ast.AST as matched result.
|
||||
|
||||
Raises:
|
||||
ValueError: If input `ast_types` is not a type nor a tuple.
|
||||
"""
|
||||
|
||||
if isinstance(ast_types, Type):
|
||||
self._targets: tuple = (ast_types,)
|
||||
else:
|
||||
if not isinstance(ast_types, tuple):
|
||||
raise ValueError("Input ast_types should be a tuple or a type")
|
||||
self._targets: tuple = ast_types
|
||||
|
||||
self._results.clear()
|
||||
self.visit(self._scope)
|
||||
return self._results
|
|
@ -16,7 +16,7 @@
|
|||
from typing import Optional
|
||||
import ast
|
||||
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from ..api.scoped_value import ScopedValue, ValueType
|
||||
|
||||
|
||||
class AstModifier(ast.NodeTransformer):
|
||||
|
@ -40,6 +40,50 @@ class AstModifier(ast.NodeTransformer):
|
|||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
|
||||
insert_before=True) -> ast.AST:
|
||||
"""
|
||||
Insert an ast node into another ast node's body.
|
||||
|
||||
Args:
|
||||
ast_father (ast.AST): Where new ast node to be inserted into.
|
||||
ast_son (ast.AST): An ast node to be inserted in.
|
||||
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_father' where new ast node to be
|
||||
inserted into. Default is None which means append new ast node to body of
|
||||
'ast_father'.
|
||||
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast node to be
|
||||
inserted into. Only valid when 'index_ast' is not None. Default is True which means
|
||||
inserting new ast node before 'index_ast'.
|
||||
|
||||
Returns:
|
||||
An instance of ast.AST which has been inserted into 'ast_father'.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'ast_father' has no attribute named 'body'.
|
||||
RuntimeError: If 'index_ast' is not contained in 'ast_father'.
|
||||
"""
|
||||
if not hasattr(ast_father, "body"):
|
||||
raise ValueError("Input ast_father has no attribute body:", type(ast_father))
|
||||
if index_ast is None:
|
||||
ast_father.body.append(ast_son)
|
||||
ast.fix_missing_locations(ast_father)
|
||||
return ast_son
|
||||
for index in range(0, len(ast_father.body)):
|
||||
if id(ast_father.body[index]) == id(index_ast):
|
||||
if insert_before:
|
||||
ast_father.body.insert(index, ast_son)
|
||||
else:
|
||||
ast_father.body.insert(index + 1, ast_son)
|
||||
ast.fix_missing_locations(ast_father)
|
||||
return ast_son
|
||||
raise RuntimeError("index_ast is not contained in ast_father")
|
||||
|
||||
@staticmethod
|
||||
def insert_class_into_module(ast_mod: ast.Module, ast_class: ast.ClassDef, index_ast: Optional[ast.AST] = None,
|
||||
insert_before=True) -> ast.ClassDef:
|
||||
return AstModifier.insert_sub_ast(ast_mod, ast_class, index_ast, insert_before)
|
||||
|
||||
@staticmethod
|
||||
def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [ScopedValue], expr: ScopedValue,
|
||||
args: [ScopedValue] = None, kwargs: {str, ScopedValue}=None,
|
||||
|
@ -162,6 +206,77 @@ class AstModifier(ast.NodeTransformer):
|
|||
ast.fix_missing_locations(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
|
||||
"""
|
||||
Create a list of ast.AST as args of ast.Call from a list of `ScopedValue`.
|
||||
|
||||
Args:
|
||||
args (list[ScopedValue]): Args of ast.Call.
|
||||
|
||||
Returns:
|
||||
A list of ast.AST as args of ast.Call.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If element of 'args' is not an instance of `ScopedValue`.
|
||||
RuntimeError: If value_type of element of 'args' is `ValueType.CustomObjValue`.
|
||||
"""
|
||||
|
||||
if args is None:
|
||||
return []
|
||||
results = []
|
||||
for arg in args:
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if arg.scope:
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
results.append(ast.Constant(value=arg.value, kind=None))
|
||||
elif arg.type == ValueType.NamingValue:
|
||||
if arg.scope:
|
||||
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
||||
else:
|
||||
results.append(ast.Name(arg.value, ast.Store()))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _create_call_kwargs(kwargs: {str: ScopedValue}) -> [ast.keyword]:
|
||||
"""
|
||||
Create a list of ast.keyword as kwargs of ast.Call from a dict of string to `ScopedValue`.
|
||||
|
||||
Args:
|
||||
kwargs (dict{str: ScopedValue}): Kwargs of ast.Call.
|
||||
|
||||
Returns:
|
||||
A list of ast.AST as args of ast.Call.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If element of 'args' is not an instance of `ScopedValue`.
|
||||
RuntimeError: If value_type of element of 'args' is `ValueType.CustomObjValue`.
|
||||
"""
|
||||
|
||||
if kwargs is None:
|
||||
return []
|
||||
results = []
|
||||
for arg, value in kwargs.items():
|
||||
if not isinstance(value, ScopedValue):
|
||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if value.scope:
|
||||
raise RuntimeError("value.scope should be empty")
|
||||
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
|
||||
elif value.type == ValueType.NamingValue:
|
||||
if value.scope:
|
||||
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),
|
||||
value.value, ast.Store())))
|
||||
else:
|
||||
results.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def create_call(expr: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None) -> ast.Call:
|
||||
"""
|
||||
|
@ -178,11 +293,7 @@ class AstModifier(ast.NodeTransformer):
|
|||
Raises:
|
||||
RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue.
|
||||
RuntimeError: If value_type of 'expr' is not ValueType.NamingValue.
|
||||
RuntimeError: If value_type of element of 'args' is ValueType.CustomObjValue.
|
||||
RuntimeError: If value_type of value of 'kwargs' is ValueType.CustomObjValue.
|
||||
TypeError: If expr is not an instance of ScopedValue.
|
||||
RuntimeError: If element of 'args' is not an instance of ScopedValue.
|
||||
RuntimeError: If value of 'kwargs' is not an instance of ScopedValue.
|
||||
"""
|
||||
if not isinstance(expr, ScopedValue):
|
||||
raise TypeError("expr should be ScopedValue, got: ", type(expr))
|
||||
|
@ -195,40 +306,8 @@ class AstModifier(ast.NodeTransformer):
|
|||
else:
|
||||
ast_func = ast.Name(expr.value, ast.Store())
|
||||
|
||||
ast_args = []
|
||||
if args is not None:
|
||||
for arg in args:
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if arg.scope:
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
ast_args.append(ast.Constant(value=arg.value, kind=None))
|
||||
elif arg.type == ValueType.NamingValue:
|
||||
if arg.scope:
|
||||
ast_args.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
||||
else:
|
||||
ast_args.append(ast.Name(arg.value, ast.Store()))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
keywords = []
|
||||
if kwargs is not None:
|
||||
for arg, value in kwargs.items():
|
||||
if not isinstance(value, ScopedValue):
|
||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if value.scope:
|
||||
raise RuntimeError("value.scope should be empty")
|
||||
keywords.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
|
||||
elif value.type == ValueType.NamingValue:
|
||||
if value.scope:
|
||||
keywords.append(ast.keyword(arg=arg,
|
||||
value=ast.Attribute(ast.Name(value.scope, ast.Load()), value.value,
|
||||
ast.Store())))
|
||||
else:
|
||||
keywords.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
ast_args = AstModifier._create_call_args(args)
|
||||
keywords = AstModifier._create_call_kwargs(kwargs)
|
||||
result = ast.Call(func=ast_func, args=ast_args, keywords=keywords)
|
||||
ast.fix_missing_locations(result)
|
||||
return result
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Replacing specific symbol name with another symbol name in specific scope."""
|
||||
|
||||
from typing import Any
|
||||
import ast
|
||||
|
||||
|
||||
class AstReplacer(ast.NodeTransformer):
|
||||
"""
|
||||
Replace all specific symbol name in specific scope with another symbol name.
|
||||
|
||||
Args:
|
||||
node (ast.AST): An instance of ast node as replace scope.
|
||||
"""
|
||||
|
||||
def __init__(self, node: ast.AST):
|
||||
self._scope = node
|
||||
self._src = ""
|
||||
self._dst = ""
|
||||
self._trace = []
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
||||
"""
|
||||
An override method, call back when visiting an ast.ClassDef node.
|
||||
|
||||
Args:
|
||||
node (ast.ClassDef): An instance of ast.ClassDef which is visited currently.
|
||||
"""
|
||||
|
||||
if node.name == self._src:
|
||||
node.name = self._dst
|
||||
self._trace.append((node, "name", self._src, self._dst))
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
"""
|
||||
An override method, call back when visiting an ast.Name node.
|
||||
|
||||
Args:
|
||||
node (ast.Name): An instance of ast.Name which is visited currently.
|
||||
"""
|
||||
|
||||
if node.id == self._src:
|
||||
node.id = self._dst
|
||||
self._trace.append((node, "id", self._src, self._dst))
|
||||
return self.generic_visit(node)
|
||||
|
||||
def replace_all(self, src: str, dst: str):
|
||||
"""
|
||||
Replace all matched symbol to new symbol name.
|
||||
|
||||
Args:
|
||||
src (str): Target symbol name to be replaced out.
|
||||
dst (str): New symbol name to be replaced in.
|
||||
"""
|
||||
|
||||
self._src = src
|
||||
self._dst = dst
|
||||
self.visit(self._scope)
|
||||
|
||||
def undo_all(self):
|
||||
"""Undo all replace-actions applied on current scope."""
|
||||
|
||||
for trace in self._trace:
|
||||
setattr(trace[0], trace[1], trace[2])
|
||||
self._trace.clear()
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
`common` package of MindSpore Rewrite package.
|
||||
Define some common instruments.
|
||||
"""
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class, observable of observer design pattern."""
|
||||
|
||||
from .observer import Observer
|
||||
|
||||
|
||||
class Observable:
|
||||
"""Abstract class, observable of observer design pattern."""
|
||||
|
||||
def __init__(self):
|
||||
self._observers: [Observer] = []
|
||||
|
||||
def changed(self):
|
||||
"""
|
||||
Called when current observable is changed.
|
||||
`Observable` declares a change and all registered observers observe a change and do something for this change.
|
||||
"""
|
||||
|
||||
for observer in self._observers:
|
||||
observer.on_change()
|
||||
|
||||
def reg_observer(self, observer: Observer):
|
||||
"""
|
||||
Register an `observer` so that it can observe changes of current observable.
|
||||
|
||||
Args:
|
||||
observer (Observer): An `Observer` to be registered into current observable.
|
||||
"""
|
||||
|
||||
self._observers.append(observer)
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Abstract class, observer of observer design pattern."""
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class Observer(abc.ABC):
|
||||
"""Abstract class, observer of observer design pattern."""
|
||||
|
||||
def __init__(self):
|
||||
self._observing = False
|
||||
|
||||
def start_observe(self):
|
||||
"""
|
||||
Start observing so that current `Observer` can do response when any change occurred in `Observable`.
|
||||
"""
|
||||
|
||||
self._observing = True
|
||||
|
||||
def stop_observe(self):
|
||||
"""
|
||||
Stop observing so that current `Observer` will do nothing even when changes occurred in linked `Observable`.
|
||||
"""
|
||||
|
||||
self._observing = False
|
||||
|
||||
def on_change(self):
|
||||
"""
|
||||
Called back when any changes occurred in linked `Observable`.
|
||||
"""
|
||||
|
||||
if self._observing:
|
||||
self._on_change()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _on_change(self):
|
||||
"""
|
||||
Abstract method for defining how to response when any changes occurred in linked `Observable`.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
|
@ -12,7 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Unique name producer for target, name of node."""
|
||||
"""Unique name producer for target, name of node, class name, etc."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from .node import Node
|
||||
|
@ -135,7 +136,7 @@ class NodeNamer(Namer):
|
|||
if isinstance(node_or_name, Node):
|
||||
origin_name = node_or_name.get_name()
|
||||
if origin_name is None or not origin_name:
|
||||
if node_or_name.get_node_type() == NodeType.CallCell:
|
||||
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive):
|
||||
if not isinstance(node_or_name, Node):
|
||||
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
|
||||
targets = node_or_name.get_targets()
|
||||
|
@ -157,3 +158,52 @@ class NodeNamer(Namer):
|
|||
else:
|
||||
raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name))
|
||||
return super(NodeNamer, self).get_name(origin_name)
|
||||
|
||||
|
||||
class ClassNamer(Namer):
|
||||
"""
|
||||
Used for unique-ing class name in a network.
|
||||
|
||||
Class name should be unique in a network, in other word, in a Rewrite process. So please do not invoke constructor
|
||||
of `ClassNamer` and call `instance()` of `ClassNamer` to obtain singleton of ClassNamer.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._prefix = "Opt"
|
||||
|
||||
@classmethod
|
||||
def instance(cls):
|
||||
"""
|
||||
Class method of `ClassNamer` for singleton of `ClassNamer`.
|
||||
|
||||
Returns:
|
||||
An instance of `ClassNamer` as singleton of `ClassNamer`.
|
||||
"""
|
||||
|
||||
if not hasattr(ClassNamer, "_instance"):
|
||||
ClassNamer._instance = ClassNamer()
|
||||
return ClassNamer._instance
|
||||
|
||||
def get_name(self, origin_class_name: str) -> str:
|
||||
"""
|
||||
Unique input `origin_class_name`.
|
||||
|
||||
Args:
|
||||
origin_class_name (str): A string represents original class name.
|
||||
|
||||
Returns:
|
||||
A string represents a unique class name generated from `origin_class_name`.
|
||||
"""
|
||||
|
||||
return super(ClassNamer, self).get_name(origin_class_name + self._prefix)
|
||||
|
||||
def add_name(self, class_name: str):
|
||||
"""
|
||||
Declare a `class_name` so that other class can not apply this `class_name` anymore.
|
||||
|
||||
Args:
|
||||
class_name (str): A string represents a class name.
|
||||
"""
|
||||
|
||||
super(ClassNamer, self).add_name(class_name + self._prefix)
|
||||
|
|
|
@ -19,7 +19,7 @@ import inspect
|
|||
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import log as logger
|
||||
from .ast_modifier import AstModifier
|
||||
from .ast_helpers import AstModifier
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .api.node_type import NodeType
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from ..symbol_tree import SymbolTree
|
|||
from ..parser import Parser
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
from ..api.scoped_value import ScopedValue
|
||||
from ..ast_modifier import AstModifier
|
||||
from ..ast_helpers import AstModifier
|
||||
|
||||
|
||||
class ClassDefParser(Parser):
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.nn import Cell
|
|||
from mindspore import log as logger
|
||||
from .node import Node, TreeNode
|
||||
from .api.node_type import NodeType
|
||||
from .ast_modifier import AstModifier
|
||||
from .ast_helpers import AstModifier
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .symbol_tree_dumper import SymbolTreeDumper
|
||||
from .topological_manager import TopoManager
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
import inspect
|
||||
from mindspore import nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.rewrite.ast_helpers import AstFinder
|
||||
|
||||
|
||||
class SimpleNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SimpleNet, self).__init__()
|
||||
self.aaa = 1
|
||||
self.bbb = F.add(1, 1)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.aaa + x
|
||||
x = self.bbb + x
|
||||
return x
|
||||
|
||||
|
||||
def test_finder_single_type():
|
||||
"""
|
||||
Feature: Class AstFinder in Package rewrite.
|
||||
Description: Use AstFinder to find all Assign ast node.
|
||||
Expectation: AstFinder can find all Assign ast node.
|
||||
"""
|
||||
ast_root = ast.parse(inspect.getsource(SimpleNet))
|
||||
finder = AstFinder(ast_root)
|
||||
results = finder.find_all(ast.Assign)
|
||||
assert len(results) == 4
|
||||
for result in results:
|
||||
assert isinstance(result, ast.Assign)
|
||||
|
||||
|
||||
def test_finder_multi_type():
|
||||
"""
|
||||
Feature: Class AstFinder in Package rewrite.
|
||||
Description: Use AstFinder to find all Assign and Attribute ast node.
|
||||
Expectation: AstFinder can find all Assign and Attribute ast node.
|
||||
"""
|
||||
ast_root = ast.parse(inspect.getsource(SimpleNet))
|
||||
finder = AstFinder(ast_root)
|
||||
results = finder.find_all((ast.Assign, ast.Attribute))
|
||||
assert len(results) == 11
|
||||
assign_num = 0
|
||||
attribute_num = 0
|
||||
for result in results:
|
||||
if isinstance(result, ast.Assign):
|
||||
assign_num += 1
|
||||
continue
|
||||
if isinstance(result, ast.Attribute):
|
||||
attribute_num += 1
|
||||
continue
|
||||
assert False
|
||||
assert assign_num == 4
|
||||
assert attribute_num == 7
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
import re
|
||||
import inspect
|
||||
import astunparse
|
||||
from mindspore import nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.rewrite.ast_helpers import AstReplacer
|
||||
|
||||
|
||||
class SimpleNet2(nn.Cell):
|
||||
def construct(self, x):
|
||||
return F.add(x, x)
|
||||
|
||||
|
||||
class SimpleNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SimpleNet, self).__init__()
|
||||
SimpleNet._get_int()
|
||||
self.aaa = SimpleNet._get_int()
|
||||
self.bbb = SimpleNet._get_int() + 1
|
||||
self.ccc = F.add(SimpleNet._get_int(), 1)
|
||||
self.ddd = SimpleNet2()
|
||||
|
||||
@staticmethod
|
||||
def _get_int():
|
||||
return 1
|
||||
|
||||
def construct(self, x):
|
||||
SimpleNet._get_int()
|
||||
aaa = SimpleNet._get_int()
|
||||
bbb = SimpleNet._get_int() + aaa
|
||||
ccc = F.add(SimpleNet._get_int(), bbb)
|
||||
x = self.ddd(ccc)
|
||||
return x
|
||||
|
||||
|
||||
def test_replacer():
|
||||
"""
|
||||
Feature: Class AstReplacer in Package rewrite.
|
||||
Description:
|
||||
Use AstReplacer to replace all "SimpleNet" symbol to "SimpleNet2" symbol.
|
||||
Use AstReplacer to undo all replace.
|
||||
Expectation: AstReplacer can replace all "SimpleNet" symbol to "SimpleNet2" symbol and restore original ast node.
|
||||
"""
|
||||
|
||||
original_code = inspect.getsource(SimpleNet)
|
||||
assert len(re.findall("SimpleNet", original_code)) == 11
|
||||
assert len(re.findall("SimpleNet2", original_code)) == 1
|
||||
|
||||
ast_root = ast.parse(original_code)
|
||||
replacer = AstReplacer(ast_root)
|
||||
replacer.replace_all("SimpleNet", "SimpleNet2")
|
||||
replaced_code = astunparse.unparse(ast_root)
|
||||
assert len(re.findall("SimpleNet", replaced_code)) == 11
|
||||
assert len(re.findall("SimpleNet2", replaced_code)) == 11
|
||||
|
||||
replacer.undo_all()
|
||||
assert len(re.findall("SimpleNet", original_code)) == 11
|
||||
assert len(re.findall("SimpleNet2", original_code)) == 1
|
Loading…
Reference in New Issue