!32363 add AstFinder, AstReplacer, Observer, Observable and ClassNamer in Rewrite

Merge pull request !32363 from hangq/mscompression-pr-dev
This commit is contained in:
i-robot 2022-04-02 05:04:37 +00:00 committed by Gitee
commit feac726ec7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 604 additions and 45 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
MindSpore Rewrite module.
MindSpore Rewrite package.
This is an experimental python package that is subject to change or deletion.
"""
from .parsers.module_parser import g_module_parser

View File

@ -0,0 +1,24 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
`ast_helpers` package of MindSpore Rewrite package.
Define some ast helpers for manipulating python ast.
"""
from .ast_finder import AstFinder
from .ast_replacer import AstReplacer
from .ast_modifier import AstModifier
__all__ = ["AstFinder", "AstReplacer", "AstModifier"]

View File

@ -0,0 +1,70 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find specific type ast node in specific scope."""
from typing import Type
import ast
class AstFinder(ast.NodeVisitor):
"""
Find all specific type ast node in specific scope.
Args:
node (ast.AST): An instance of ast node as search scope.
"""
def __init__(self, node: ast.AST):
self._scope: ast.AST = node
self._targets: tuple = ()
self._results: [ast.AST] = []
def generic_visit(self, node):
"""
An override method, iterating over all nodes and save target ast nodes.
Args:
node (ast.AST): An instance of ast node which is visited currently.
"""
if isinstance(node, self._targets):
self._results.append(node)
super(AstFinder, self).generic_visit(node)
def find_all(self, ast_types) -> [ast.AST]:
"""
Find all matched ast node.
Args:
ast_types (Union[tuple(Type), Type]): A tuple of Type or a Type indicates target ast node type.
Returns:
A list of instance of ast.AST as matched result.
Raises:
ValueError: If input `ast_types` is not a type nor a tuple.
"""
if isinstance(ast_types, Type):
self._targets: tuple = (ast_types,)
else:
if not isinstance(ast_types, tuple):
raise ValueError("Input ast_types should be a tuple or a type")
self._targets: tuple = ast_types
self._results.clear()
self.visit(self._scope)
return self._results

View File

@ -16,7 +16,7 @@
from typing import Optional
import ast
from .api.scoped_value import ScopedValue, ValueType
from ..api.scoped_value import ScopedValue, ValueType
class AstModifier(ast.NodeTransformer):
@ -40,6 +40,50 @@ class AstModifier(ast.NodeTransformer):
return True
return False
@staticmethod
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
insert_before=True) -> ast.AST:
"""
Insert an ast node into another ast node's body.
Args:
ast_father (ast.AST): Where new ast node to be inserted into.
ast_son (ast.AST): An ast node to be inserted in.
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_father' where new ast node to be
inserted into. Default is None which means append new ast node to body of
'ast_father'.
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast node to be
inserted into. Only valid when 'index_ast' is not None. Default is True which means
inserting new ast node before 'index_ast'.
Returns:
An instance of ast.AST which has been inserted into 'ast_father'.
Raises:
ValueError: If 'ast_father' has no attribute named 'body'.
RuntimeError: If 'index_ast' is not contained in 'ast_father'.
"""
if not hasattr(ast_father, "body"):
raise ValueError("Input ast_father has no attribute body:", type(ast_father))
if index_ast is None:
ast_father.body.append(ast_son)
ast.fix_missing_locations(ast_father)
return ast_son
for index in range(0, len(ast_father.body)):
if id(ast_father.body[index]) == id(index_ast):
if insert_before:
ast_father.body.insert(index, ast_son)
else:
ast_father.body.insert(index + 1, ast_son)
ast.fix_missing_locations(ast_father)
return ast_son
raise RuntimeError("index_ast is not contained in ast_father")
@staticmethod
def insert_class_into_module(ast_mod: ast.Module, ast_class: ast.ClassDef, index_ast: Optional[ast.AST] = None,
insert_before=True) -> ast.ClassDef:
return AstModifier.insert_sub_ast(ast_mod, ast_class, index_ast, insert_before)
@staticmethod
def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [ScopedValue], expr: ScopedValue,
args: [ScopedValue] = None, kwargs: {str, ScopedValue}=None,
@ -162,6 +206,77 @@ class AstModifier(ast.NodeTransformer):
ast.fix_missing_locations(result)
return result
@staticmethod
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
"""
Create a list of ast.AST as args of ast.Call from a list of `ScopedValue`.
Args:
args (list[ScopedValue]): Args of ast.Call.
Returns:
A list of ast.AST as args of ast.Call.
Raises:
RuntimeError: If element of 'args' is not an instance of `ScopedValue`.
RuntimeError: If value_type of element of 'args' is `ValueType.CustomObjValue`.
"""
if args is None:
return []
results = []
for arg in args:
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if arg.scope:
raise RuntimeError("arg.scope should be empty")
results.append(ast.Constant(value=arg.value, kind=None))
elif arg.type == ValueType.NamingValue:
if arg.scope:
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
else:
results.append(ast.Name(arg.value, ast.Store()))
else:
raise RuntimeError("Please handle custom-object first")
return results
@staticmethod
def _create_call_kwargs(kwargs: {str: ScopedValue}) -> [ast.keyword]:
"""
Create a list of ast.keyword as kwargs of ast.Call from a dict of string to `ScopedValue`.
Args:
kwargs (dict{str: ScopedValue}): Kwargs of ast.Call.
Returns:
A list of ast.AST as args of ast.Call.
Raises:
RuntimeError: If element of 'args' is not an instance of `ScopedValue`.
RuntimeError: If value_type of element of 'args' is `ValueType.CustomObjValue`.
"""
if kwargs is None:
return []
results = []
for arg, value in kwargs.items():
if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("value.scope should be empty")
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
elif value.type == ValueType.NamingValue:
if value.scope:
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),
value.value, ast.Store())))
else:
results.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
else:
raise RuntimeError("Please handle custom-object first")
return results
@staticmethod
def create_call(expr: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None) -> ast.Call:
"""
@ -178,11 +293,7 @@ class AstModifier(ast.NodeTransformer):
Raises:
RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue.
RuntimeError: If value_type of 'expr' is not ValueType.NamingValue.
RuntimeError: If value_type of element of 'args' is ValueType.CustomObjValue.
RuntimeError: If value_type of value of 'kwargs' is ValueType.CustomObjValue.
TypeError: If expr is not an instance of ScopedValue.
RuntimeError: If element of 'args' is not an instance of ScopedValue.
RuntimeError: If value of 'kwargs' is not an instance of ScopedValue.
"""
if not isinstance(expr, ScopedValue):
raise TypeError("expr should be ScopedValue, got: ", type(expr))
@ -195,40 +306,8 @@ class AstModifier(ast.NodeTransformer):
else:
ast_func = ast.Name(expr.value, ast.Store())
ast_args = []
if args is not None:
for arg in args:
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if arg.scope:
raise RuntimeError("arg.scope should be empty")
ast_args.append(ast.Constant(value=arg.value, kind=None))
elif arg.type == ValueType.NamingValue:
if arg.scope:
ast_args.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
else:
ast_args.append(ast.Name(arg.value, ast.Store()))
else:
raise RuntimeError("Please handle custom-object first")
keywords = []
if kwargs is not None:
for arg, value in kwargs.items():
if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("value.scope should be empty")
keywords.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
elif value.type == ValueType.NamingValue:
if value.scope:
keywords.append(ast.keyword(arg=arg,
value=ast.Attribute(ast.Name(value.scope, ast.Load()), value.value,
ast.Store())))
else:
keywords.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
else:
raise RuntimeError("Please handle custom-object first")
ast_args = AstModifier._create_call_args(args)
keywords = AstModifier._create_call_kwargs(kwargs)
result = ast.Call(func=ast_func, args=ast_args, keywords=keywords)
ast.fix_missing_locations(result)
return result

View File

@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Replacing specific symbol name with another symbol name in specific scope."""
from typing import Any
import ast
class AstReplacer(ast.NodeTransformer):
"""
Replace all specific symbol name in specific scope with another symbol name.
Args:
node (ast.AST): An instance of ast node as replace scope.
"""
def __init__(self, node: ast.AST):
self._scope = node
self._src = ""
self._dst = ""
self._trace = []
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
"""
An override method, call back when visiting an ast.ClassDef node.
Args:
node (ast.ClassDef): An instance of ast.ClassDef which is visited currently.
"""
if node.name == self._src:
node.name = self._dst
self._trace.append((node, "name", self._src, self._dst))
return self.generic_visit(node)
def visit_Name(self, node: ast.Name) -> Any:
"""
An override method, call back when visiting an ast.Name node.
Args:
node (ast.Name): An instance of ast.Name which is visited currently.
"""
if node.id == self._src:
node.id = self._dst
self._trace.append((node, "id", self._src, self._dst))
return self.generic_visit(node)
def replace_all(self, src: str, dst: str):
"""
Replace all matched symbol to new symbol name.
Args:
src (str): Target symbol name to be replaced out.
dst (str): New symbol name to be replaced in.
"""
self._src = src
self._dst = dst
self.visit(self._scope)
def undo_all(self):
"""Undo all replace-actions applied on current scope."""
for trace in self._trace:
setattr(trace[0], trace[1], trace[2])
self._trace.clear()

View File

@ -0,0 +1,18 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
`common` package of MindSpore Rewrite package.
Define some common instruments.
"""

View File

@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class, observable of observer design pattern."""
from .observer import Observer
class Observable:
"""Abstract class, observable of observer design pattern."""
def __init__(self):
self._observers: [Observer] = []
def changed(self):
"""
Called when current observable is changed.
`Observable` declares a change and all registered observers observe a change and do something for this change.
"""
for observer in self._observers:
observer.on_change()
def reg_observer(self, observer: Observer):
"""
Register an `observer` so that it can observe changes of current observable.
Args:
observer (Observer): An `Observer` to be registered into current observable.
"""
self._observers.append(observer)

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Abstract class, observer of observer design pattern."""
import abc
class Observer(abc.ABC):
"""Abstract class, observer of observer design pattern."""
def __init__(self):
self._observing = False
def start_observe(self):
"""
Start observing so that current `Observer` can do response when any change occurred in `Observable`.
"""
self._observing = True
def stop_observe(self):
"""
Stop observing so that current `Observer` will do nothing even when changes occurred in linked `Observable`.
"""
self._observing = False
def on_change(self):
"""
Called back when any changes occurred in linked `Observable`.
"""
if self._observing:
self._on_change()
@abc.abstractmethod
def _on_change(self):
"""
Abstract method for defining how to response when any changes occurred in linked `Observable`.
"""
raise NotImplementedError

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unique name producer for target, name of node."""
"""Unique name producer for target, name of node, class name, etc."""
from typing import Union
from .node import Node
@ -135,7 +136,7 @@ class NodeNamer(Namer):
if isinstance(node_or_name, Node):
origin_name = node_or_name.get_name()
if origin_name is None or not origin_name:
if node_or_name.get_node_type() == NodeType.CallCell:
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive):
if not isinstance(node_or_name, Node):
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
targets = node_or_name.get_targets()
@ -157,3 +158,52 @@ class NodeNamer(Namer):
else:
raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name))
return super(NodeNamer, self).get_name(origin_name)
class ClassNamer(Namer):
"""
Used for unique-ing class name in a network.
Class name should be unique in a network, in other word, in a Rewrite process. So please do not invoke constructor
of `ClassNamer` and call `instance()` of `ClassNamer` to obtain singleton of ClassNamer.
"""
def __init__(self):
super().__init__()
self._prefix = "Opt"
@classmethod
def instance(cls):
"""
Class method of `ClassNamer` for singleton of `ClassNamer`.
Returns:
An instance of `ClassNamer` as singleton of `ClassNamer`.
"""
if not hasattr(ClassNamer, "_instance"):
ClassNamer._instance = ClassNamer()
return ClassNamer._instance
def get_name(self, origin_class_name: str) -> str:
"""
Unique input `origin_class_name`.
Args:
origin_class_name (str): A string represents original class name.
Returns:
A string represents a unique class name generated from `origin_class_name`.
"""
return super(ClassNamer, self).get_name(origin_class_name + self._prefix)
def add_name(self, class_name: str):
"""
Declare a `class_name` so that other class can not apply this `class_name` anymore.
Args:
class_name (str): A string represents a class name.
"""
super(ClassNamer, self).add_name(class_name + self._prefix)

View File

@ -19,7 +19,7 @@ import inspect
from mindspore.nn import Cell
from mindspore import log as logger
from .ast_modifier import AstModifier
from .ast_helpers import AstModifier
from .api.scoped_value import ScopedValue, ValueType
from .api.node_type import NodeType

View File

@ -20,7 +20,7 @@ from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import ParserRegister, reg_parser
from ..api.scoped_value import ScopedValue
from ..ast_modifier import AstModifier
from ..ast_helpers import AstModifier
class ClassDefParser(Parser):

View File

@ -24,7 +24,7 @@ from mindspore.nn import Cell
from mindspore import log as logger
from .node import Node, TreeNode
from .api.node_type import NodeType
from .ast_modifier import AstModifier
from .ast_helpers import AstModifier
from .api.scoped_value import ScopedValue, ValueType
from .symbol_tree_dumper import SymbolTreeDumper
from .topological_manager import TopoManager

View File

@ -0,0 +1,69 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import inspect
from mindspore import nn
from mindspore.ops import functional as F
from mindspore.rewrite.ast_helpers import AstFinder
class SimpleNet(nn.Cell):
def __init__(self):
super(SimpleNet, self).__init__()
self.aaa = 1
self.bbb = F.add(1, 1)
def construct(self, x):
x = self.aaa + x
x = self.bbb + x
return x
def test_finder_single_type():
"""
Feature: Class AstFinder in Package rewrite.
Description: Use AstFinder to find all Assign ast node.
Expectation: AstFinder can find all Assign ast node.
"""
ast_root = ast.parse(inspect.getsource(SimpleNet))
finder = AstFinder(ast_root)
results = finder.find_all(ast.Assign)
assert len(results) == 4
for result in results:
assert isinstance(result, ast.Assign)
def test_finder_multi_type():
"""
Feature: Class AstFinder in Package rewrite.
Description: Use AstFinder to find all Assign and Attribute ast node.
Expectation: AstFinder can find all Assign and Attribute ast node.
"""
ast_root = ast.parse(inspect.getsource(SimpleNet))
finder = AstFinder(ast_root)
results = finder.find_all((ast.Assign, ast.Attribute))
assert len(results) == 11
assign_num = 0
attribute_num = 0
for result in results:
if isinstance(result, ast.Assign):
assign_num += 1
continue
if isinstance(result, ast.Attribute):
attribute_num += 1
continue
assert False
assert assign_num == 4
assert attribute_num == 7

View File

@ -0,0 +1,73 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import re
import inspect
import astunparse
from mindspore import nn
from mindspore.ops import functional as F
from mindspore.rewrite.ast_helpers import AstReplacer
class SimpleNet2(nn.Cell):
def construct(self, x):
return F.add(x, x)
class SimpleNet(nn.Cell):
def __init__(self):
super(SimpleNet, self).__init__()
SimpleNet._get_int()
self.aaa = SimpleNet._get_int()
self.bbb = SimpleNet._get_int() + 1
self.ccc = F.add(SimpleNet._get_int(), 1)
self.ddd = SimpleNet2()
@staticmethod
def _get_int():
return 1
def construct(self, x):
SimpleNet._get_int()
aaa = SimpleNet._get_int()
bbb = SimpleNet._get_int() + aaa
ccc = F.add(SimpleNet._get_int(), bbb)
x = self.ddd(ccc)
return x
def test_replacer():
"""
Feature: Class AstReplacer in Package rewrite.
Description:
Use AstReplacer to replace all "SimpleNet" symbol to "SimpleNet2" symbol.
Use AstReplacer to undo all replace.
Expectation: AstReplacer can replace all "SimpleNet" symbol to "SimpleNet2" symbol and restore original ast node.
"""
original_code = inspect.getsource(SimpleNet)
assert len(re.findall("SimpleNet", original_code)) == 11
assert len(re.findall("SimpleNet2", original_code)) == 1
ast_root = ast.parse(original_code)
replacer = AstReplacer(ast_root)
replacer.replace_all("SimpleNet", "SimpleNet2")
replaced_code = astunparse.unparse(ast_root)
assert len(re.findall("SimpleNet", replaced_code)) == 11
assert len(re.findall("SimpleNet2", replaced_code)) == 11
replacer.undo_all()
assert len(re.findall("SimpleNet", original_code)) == 11
assert len(re.findall("SimpleNet2", original_code)) == 1