fix rewrite bug

This commit is contained in:
yuzhenhua 2022-06-02 14:21:39 +08:00
parent a95fd90b07
commit fa60ab7c84
5 changed files with 208 additions and 19 deletions

View File

@ -253,8 +253,29 @@ class SymbolTree:
def get_network(self) -> Cell: def get_network(self) -> Cell:
""" """
Get modified network. Get modified network.
The source code of network is saved to a file, the default file name is `network_define.py`.
Returns: Returns:
A network object. A network object.
""" """
return self._symbol_tree.get_network() return self._symbol_tree.get_network()
def set_saved_file_name(self, file_name: str):
"""
Set the name of the file used to save the network.
Args:
file_name (str): filename to be set.
"""
Validator.check_value_type("file_name", file_name, [str], "Saving network")
self._symbol_tree.set_saved_file_name(file_name)
def get_saved_file_name(self):
"""Gets the filename used to save the network."""
return self._symbol_tree.get_saved_file_name()
def save_network_to_file(self):
"""
Save the modified network to a file. Default file name is `network_define.py`.
"""
self._symbol_tree.save_network_to_file()

View File

@ -256,6 +256,75 @@ class AstModifier(ast.NodeTransformer):
ast.fix_missing_locations(result) ast.fix_missing_locations(result)
return result return result
@staticmethod
def _create_arg_by_single_value(value: ScopedValue):
"""
Create an instance of ast.Constant.
Args:
value (ScopedValue): value used to create arg.
Raises:
RuntimeError: if scope of value is not empty.
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]
Returns:
ast.Constant: An instance of ast.Constant
"""
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("For arg the scope should be empty")
return ast.Constant(value=value.value, kind=None)
raise TypeError("Type of arg only support [ValueType.IntValue, ValueType.FloatValue,"
f" ValueType.StringValue], but got {type(value)}")
@staticmethod
def _create_list_or_tuple(value: ScopedValue):
"""
Create an instance of ast.List or ast.Tuple.
Args:
value (ScopedValue): value used to create ast node.
Returns:
ast.List or ast.Tuple: An instance of ast.List or ast.Tuple.
"""
elts = []
for v in value.value:
elts.append(AstModifier._create_arg_by_single_value(v))
if isinstance(value, list):
return ast.List(elts=elts)
return ast.Tuple(elts=elts)
@staticmethod
def _create_keyword(arg: str, value: ScopedValue):
"""
Create an instance of ast.keyword.
Args:
arg (str): key of keyword.
value (ScopedValue): value used to create ast.keywrod instance.
Raises:
RuntimeError: if scope of value is not empty.
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
ValueType.ListValue, ValueType.TupleValue]
Returns:
ast.keyword: a instance of ast.keyword.
"""
if value.scope:
raise RuntimeError("value.scope should be empty")
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
v = ast.Constant(value=value.value, kind=None)
elif value.type in (ValueType.ListValue, ValueType.TupleValue):
v = AstModifier._create_list_or_tuple(value)
else:
raise TypeError("Type of keyword value only support [ValueType.IntValue, ValueType.FloatValue,"
"ValueType.StringValue, ValueType.ListValue, ValueType.TupleValue],"
f"but got {type(value)}")
return ast.keyword(arg=arg, value=v)
@staticmethod @staticmethod
def _create_call_args(args: [ScopedValue]) -> [ast.AST]: def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
""" """
@ -279,14 +348,14 @@ class AstModifier(ast.NodeTransformer):
if not isinstance(arg, ScopedValue): if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg)) raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if arg.scope: results.append(AstModifier._create_arg_by_single_value(arg))
raise RuntimeError("arg.scope should be empty")
results.append(ast.Constant(value=arg.value, kind=None))
elif arg.type == ValueType.NamingValue: elif arg.type == ValueType.NamingValue:
if arg.scope: if arg.scope:
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store())) results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
else: else:
results.append(ast.Name(arg.value, ast.Store())) results.append(ast.Name(arg.value, ast.Store()))
elif arg.type == ValueType.ListValue or arg.type == ValueType.TupleValue:
results.append(AstModifier._create_list_or_tuple(arg))
else: else:
raise RuntimeError("Please handle custom-object first") raise RuntimeError("Please handle custom-object first")
return results return results
@ -313,10 +382,9 @@ class AstModifier(ast.NodeTransformer):
for arg, value in kwargs.items(): for arg, value in kwargs.items():
if not isinstance(value, ScopedValue): if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value)) raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
if value.scope: ValueType.ListValue, ValueType.TupleValue):
raise RuntimeError("value.scope should be empty") results.append(AstModifier._create_keyword(arg, value))
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
elif value.type == ValueType.NamingValue: elif value.type == ValueType.NamingValue:
if value.scope: if value.scope:
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()), results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),

View File

@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""SymbolTree class define of Rewrite according to forward function of a network.""" """SymbolTree class define of Rewrite according to forward function of a network."""
import stat
from typing import Optional, Union, Tuple, Any from typing import Optional, Union, Tuple, Any
import os import os
import sys import sys
import ast import ast
import tempfile import importlib
import astunparse import astunparse
from mindspore.nn import Cell from mindspore.nn import Cell
@ -176,6 +177,7 @@ class SymbolTree(Observer, Observable):
self._tmp_file_limits = 20 self._tmp_file_limits = 20
self._tmp_files = [] self._tmp_files = []
self._saved_file_name = "./network_define.py"
def __del__(self): def __del__(self):
for tmp_file in self._tmp_files: for tmp_file in self._tmp_files:
@ -923,6 +925,27 @@ class SymbolTree(Observer, Observable):
cls = self._get_cls_through_file() cls = self._get_cls_through_file()
return cls(self._global_vars) return cls(self._global_vars)
def set_saved_file_name(self, file_name: str):
"""Sets the filename used to save the network."""
if file_name.endswith(".py"):
self._saved_file_name = file_name
else:
self._saved_file_name = file_name + ".py"
def get_saved_file_name(self):
"""Gets the filename used to save the network."""
return self._saved_file_name
def save_network_to_file(self):
"""Save the modified network to a file."""
abs_path = os.path.abspath(self._saved_file_name)
if os.path.isfile(abs_path):
os.remove(abs_path)
with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
source = self.get_code()
f.write(source.encode('utf-8'))
f.flush()
def _remove_unused_import(self): def _remove_unused_import(self):
"""remove unused import in self._module_ast""" """remove unused import in self._module_ast"""
str_checker = StrChecker(self._module_ast) str_checker = StrChecker(self._module_ast)
@ -1177,19 +1200,12 @@ class SymbolTree(Observer, Observable):
Returns: Returns:
A class handle. A class handle.
""" """
source = self.get_code() self.save_network_to_file()
tmp_file = tempfile.NamedTemporaryFile(suffix='.py') tmp_module_path, tmp_module_file = os.path.split(self._saved_file_name)
tmp_file.write(source.encode('utf8'))
tmp_file.flush()
tmp_file_name = tmp_file.name
if len(self._tmp_files) >= self._tmp_file_limits:
raise RuntimeError(f"Too many tmp file generated, it may caused by calling get_network method too much "
f"times. Only support open {self._tmp_file_limits} tmp file at most now!")
self._tmp_files.append(tmp_file)
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
tmp_module_name = tmp_module_file[:-3] tmp_module_name = tmp_module_file[:-3]
sys.path.append(tmp_module_path) sys.path.append(tmp_module_path)
tmp_module = __import__(tmp_module_name) tmp_module = importlib.import_module(tmp_module_name)
importlib.reload(tmp_module)
network_cls = getattr(tmp_module, self._opt_cls_name) network_cls = getattr(tmp_module, self._opt_cls_name)
if network_cls is None: if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name) raise RuntimeError("Can not find network class:", self._opt_cls_name)

View File

@ -128,3 +128,51 @@ def test_create_by_cell3():
"h": ScopedValue.create_variable_value(1), "h": ScopedValue.create_variable_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto'), "cool_boy": ScopedValue.create_naming_value('Naroto'),
} }
def test_create_by_cell4():
"""
Feature: Python api create_call_buildin_op of Node of Rewrite.
Description: Call create_call_buildin_op to create a CallCell node.
Expectation: Success.
"""
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
ScopedValue.create_naming_value('new_conv'),
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
{"h": ScopedValue.create_variable_value([1]),
"f": ScopedValue.create_variable_value((2,)),
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_naming_value('x'),
"b": ScopedValue.create_naming_value('x'),
"args_2": ScopedValue.create_naming_value('x'),
"args_3": ScopedValue.create_naming_value('x'),
"f": ScopedValue.create_variable_value((2,)),
"h": ScopedValue.create_variable_value([1]),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}
def test_create_by_cell5():
"""
Feature: Python api create_call_buildin_op of Node of Rewrite.
Description: Call create_call_buildin_op to create a CallCell node.
Expectation: Success.
"""
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
ScopedValue.create_naming_value('new_conv'),
[ScopedValue.create_variable_value((4,)), ScopedValue.create_variable_value(5),
ScopedValue.create_variable_value([5]), ScopedValue.create_naming_value("x")],
{"h": ScopedValue.create_variable_value(1),
"f": ScopedValue.create_variable_value(2),
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_variable_value((4,)),
"b": ScopedValue.create_variable_value(5),
"args_2": ScopedValue.create_variable_value([5]),
"args_3": ScopedValue.create_naming_value('x'),
"f": ScopedValue.create_variable_value(2),
"h": ScopedValue.create_variable_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import os
import ast import ast
import inspect import inspect
@ -432,3 +433,38 @@ def test_replace_one_to_multi():
assert len(new_relu_node.get_targets()) == 1 assert len(new_relu_node.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1 assert len(relu2.get_normalized_args().values()) == 1
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0] assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
def test_set_saved_file_name():
"""
Feature: Python api set_saved_file_name and get_saved_file_name of SymbolTree of Rewrite.
Description: Call set_saved_file_name to set the filename used to save the network.
Call get_saved_file_name to get the filename used to save the network.
Expectation: Success.
"""
stree, _, _, _ = create_symbol_tree()
stree.set_saved_file_name("new_network.py")
new_file_name = stree.get_saved_file_name()
assert new_file_name == "new_network.py"
stree.set_saved_file_name("new_network_01")
new_file_name = stree.get_saved_file_name()
assert new_file_name == "new_network_01.py"
def test_save_network_to_file():
"""
Feature: Python api save_network_to_file of SymbolTree of Rewrite.
Description: Call save_network_to_file to save the network to a file.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
stree.set_node_arg_by_node(relu2, 0, bn)
stree.erase_node(relu1)
stree.set_saved_file_name("new_network.py")
stree.save_network_to_file()
assert os.path.exists("./new_network.py")
os.system("rm -f new_network.py")