fix rewrite bug

This commit is contained in:
yuzhenhua 2022-06-02 14:21:39 +08:00
parent a95fd90b07
commit fa60ab7c84
5 changed files with 208 additions and 19 deletions

View File

@ -253,8 +253,29 @@ class SymbolTree:
def get_network(self) -> Cell:
"""
Get modified network.
The source code of network is saved to a file, the default file name is `network_define.py`.
Returns:
A network object.
"""
return self._symbol_tree.get_network()
def set_saved_file_name(self, file_name: str):
"""
Set the name of the file used to save the network.
Args:
file_name (str): filename to be set.
"""
Validator.check_value_type("file_name", file_name, [str], "Saving network")
self._symbol_tree.set_saved_file_name(file_name)
def get_saved_file_name(self):
"""Gets the filename used to save the network."""
return self._symbol_tree.get_saved_file_name()
def save_network_to_file(self):
"""
Save the modified network to a file. Default file name is `network_define.py`.
"""
self._symbol_tree.save_network_to_file()

View File

@ -256,6 +256,75 @@ class AstModifier(ast.NodeTransformer):
ast.fix_missing_locations(result)
return result
@staticmethod
def _create_arg_by_single_value(value: ScopedValue):
"""
Create an instance of ast.Constant.
Args:
value (ScopedValue): value used to create arg.
Raises:
RuntimeError: if scope of value is not empty.
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]
Returns:
ast.Constant: An instance of ast.Constant
"""
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("For arg the scope should be empty")
return ast.Constant(value=value.value, kind=None)
raise TypeError("Type of arg only support [ValueType.IntValue, ValueType.FloatValue,"
f" ValueType.StringValue], but got {type(value)}")
@staticmethod
def _create_list_or_tuple(value: ScopedValue):
"""
Create an instance of ast.List or ast.Tuple.
Args:
value (ScopedValue): value used to create ast node.
Returns:
ast.List or ast.Tuple: An instance of ast.List or ast.Tuple.
"""
elts = []
for v in value.value:
elts.append(AstModifier._create_arg_by_single_value(v))
if isinstance(value, list):
return ast.List(elts=elts)
return ast.Tuple(elts=elts)
@staticmethod
def _create_keyword(arg: str, value: ScopedValue):
"""
Create an instance of ast.keyword.
Args:
arg (str): key of keyword.
value (ScopedValue): value used to create ast.keywrod instance.
Raises:
RuntimeError: if scope of value is not empty.
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
ValueType.ListValue, ValueType.TupleValue]
Returns:
ast.keyword: a instance of ast.keyword.
"""
if value.scope:
raise RuntimeError("value.scope should be empty")
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
v = ast.Constant(value=value.value, kind=None)
elif value.type in (ValueType.ListValue, ValueType.TupleValue):
v = AstModifier._create_list_or_tuple(value)
else:
raise TypeError("Type of keyword value only support [ValueType.IntValue, ValueType.FloatValue,"
"ValueType.StringValue, ValueType.ListValue, ValueType.TupleValue],"
f"but got {type(value)}")
return ast.keyword(arg=arg, value=v)
@staticmethod
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
"""
@ -279,14 +348,14 @@ class AstModifier(ast.NodeTransformer):
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if arg.scope:
raise RuntimeError("arg.scope should be empty")
results.append(ast.Constant(value=arg.value, kind=None))
results.append(AstModifier._create_arg_by_single_value(arg))
elif arg.type == ValueType.NamingValue:
if arg.scope:
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
else:
results.append(ast.Name(arg.value, ast.Store()))
elif arg.type == ValueType.ListValue or arg.type == ValueType.TupleValue:
results.append(AstModifier._create_list_or_tuple(arg))
else:
raise RuntimeError("Please handle custom-object first")
return results
@ -313,10 +382,9 @@ class AstModifier(ast.NodeTransformer):
for arg, value in kwargs.items():
if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("value.scope should be empty")
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
ValueType.ListValue, ValueType.TupleValue):
results.append(AstModifier._create_keyword(arg, value))
elif value.type == ValueType.NamingValue:
if value.scope:
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),

View File

@ -13,11 +13,12 @@
# limitations under the License.
# ============================================================================
"""SymbolTree class define of Rewrite according to forward function of a network."""
import stat
from typing import Optional, Union, Tuple, Any
import os
import sys
import ast
import tempfile
import importlib
import astunparse
from mindspore.nn import Cell
@ -176,6 +177,7 @@ class SymbolTree(Observer, Observable):
self._tmp_file_limits = 20
self._tmp_files = []
self._saved_file_name = "./network_define.py"
def __del__(self):
for tmp_file in self._tmp_files:
@ -923,6 +925,27 @@ class SymbolTree(Observer, Observable):
cls = self._get_cls_through_file()
return cls(self._global_vars)
def set_saved_file_name(self, file_name: str):
"""Sets the filename used to save the network."""
if file_name.endswith(".py"):
self._saved_file_name = file_name
else:
self._saved_file_name = file_name + ".py"
def get_saved_file_name(self):
"""Gets the filename used to save the network."""
return self._saved_file_name
def save_network_to_file(self):
"""Save the modified network to a file."""
abs_path = os.path.abspath(self._saved_file_name)
if os.path.isfile(abs_path):
os.remove(abs_path)
with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
source = self.get_code()
f.write(source.encode('utf-8'))
f.flush()
def _remove_unused_import(self):
"""remove unused import in self._module_ast"""
str_checker = StrChecker(self._module_ast)
@ -1177,19 +1200,12 @@ class SymbolTree(Observer, Observable):
Returns:
A class handle.
"""
source = self.get_code()
tmp_file = tempfile.NamedTemporaryFile(suffix='.py')
tmp_file.write(source.encode('utf8'))
tmp_file.flush()
tmp_file_name = tmp_file.name
if len(self._tmp_files) >= self._tmp_file_limits:
raise RuntimeError(f"Too many tmp file generated, it may caused by calling get_network method too much "
f"times. Only support open {self._tmp_file_limits} tmp file at most now!")
self._tmp_files.append(tmp_file)
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
self.save_network_to_file()
tmp_module_path, tmp_module_file = os.path.split(self._saved_file_name)
tmp_module_name = tmp_module_file[:-3]
sys.path.append(tmp_module_path)
tmp_module = __import__(tmp_module_name)
tmp_module = importlib.import_module(tmp_module_name)
importlib.reload(tmp_module)
network_cls = getattr(tmp_module, self._opt_cls_name)
if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name)

View File

@ -128,3 +128,51 @@ def test_create_by_cell3():
"h": ScopedValue.create_variable_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}
def test_create_by_cell4():
"""
Feature: Python api create_call_buildin_op of Node of Rewrite.
Description: Call create_call_buildin_op to create a CallCell node.
Expectation: Success.
"""
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
ScopedValue.create_naming_value('new_conv'),
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
{"h": ScopedValue.create_variable_value([1]),
"f": ScopedValue.create_variable_value((2,)),
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_naming_value('x'),
"b": ScopedValue.create_naming_value('x'),
"args_2": ScopedValue.create_naming_value('x'),
"args_3": ScopedValue.create_naming_value('x'),
"f": ScopedValue.create_variable_value((2,)),
"h": ScopedValue.create_variable_value([1]),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}
def test_create_by_cell5():
"""
Feature: Python api create_call_buildin_op of Node of Rewrite.
Description: Call create_call_buildin_op to create a CallCell node.
Expectation: Success.
"""
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
ScopedValue.create_naming_value('new_conv'),
[ScopedValue.create_variable_value((4,)), ScopedValue.create_variable_value(5),
ScopedValue.create_variable_value([5]), ScopedValue.create_naming_value("x")],
{"h": ScopedValue.create_variable_value(1),
"f": ScopedValue.create_variable_value(2),
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_variable_value((4,)),
"b": ScopedValue.create_variable_value(5),
"args_2": ScopedValue.create_variable_value([5]),
"args_3": ScopedValue.create_naming_value('x'),
"f": ScopedValue.create_variable_value(2),
"h": ScopedValue.create_variable_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import ast
import inspect
@ -432,3 +433,38 @@ def test_replace_one_to_multi():
assert len(new_relu_node.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
def test_set_saved_file_name():
"""
Feature: Python api set_saved_file_name and get_saved_file_name of SymbolTree of Rewrite.
Description: Call set_saved_file_name to set the filename used to save the network.
Call get_saved_file_name to get the filename used to save the network.
Expectation: Success.
"""
stree, _, _, _ = create_symbol_tree()
stree.set_saved_file_name("new_network.py")
new_file_name = stree.get_saved_file_name()
assert new_file_name == "new_network.py"
stree.set_saved_file_name("new_network_01")
new_file_name = stree.get_saved_file_name()
assert new_file_name == "new_network_01.py"
def test_save_network_to_file():
"""
Feature: Python api save_network_to_file of SymbolTree of Rewrite.
Description: Call save_network_to_file to save the network to a file.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
stree.set_node_arg_by_node(relu2, 0, bn)
stree.erase_node(relu1)
stree.set_saved_file_name("new_network.py")
stree.save_network_to_file()
assert os.path.exists("./new_network.py")
os.system("rm -f new_network.py")