!48071 fix rewrite if

Merge pull request !48071 from 于振华/fix_rewrite_if_0119
This commit is contained in:
i-robot 2023-02-16 06:36:14 +00:00 committed by Gitee
commit 40612d7235
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 46 additions and 30 deletions

View File

@ -15,7 +15,6 @@
"""Parse ast.Assign in construct function to node of SymbolTree."""
from typing import Union
import ast
import os
import astunparse
from mindspore import log as logger
@ -537,12 +536,9 @@ class AssignParser(Parser):
f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, "
f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
child_node=value, father_node=node))
except RuntimeError as e:
if os.getenv("STREE_PYTHON_FALLBACK"):
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
stree.try_append_python_node(node, node)
else:
raise e
except RuntimeError:
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
stree.try_append_python_node(node, node)
g_assign_parser = reg_parser(AssignParser())

View File

@ -67,11 +67,12 @@ class ForParser(Parser):
iter_code = iter_code.replace("self", "stree.get_origin_network()")
try:
iter_obj = eval(iter_code)
except Exception as e:
error_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
except (NameError, TypeError) as e:
_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
f"an error occurred: {str(e)}"
logger.error(error_info)
raise e
logger.warning(_info)
stree.try_append_python_node(node, node)
return
iter_var_name = iter_code.split(".")[-1]
index = stree.get_ast_root().body.index(node) + 1
@ -85,14 +86,17 @@ class ForParser(Parser):
index += 1
if stree.get_ori_cls_name() == "SequentialCell":
stree.on_change(Event.CodeChangeEvent)
elif isinstance(iter_obj, range):
raise NotImplementedError("For MindSpore Rewrtie, range not support")
stree.get_ast_root().body.remove(node)
return
if isinstance(iter_obj, range):
logger.warning("For MindSpore Rewrtie, range not support.")
elif isinstance(iter_obj, zip):
raise NotImplementedError("For MindSpore Rewrtie, zip not support")
logger.warning("For MindSpore Rewrtie, zip not support.")
elif isinstance(iter_obj, enumerate):
raise NotImplementedError("For MindSpore Rewrtie, enumerate not support")
logger.warning("For MindSpore Rewrtie, enumerate not support.")
else:
raise ValueError("For MindSpore Rewrtie, not supported type: ", iter_obj)
logger.warning("For MindSpore Rewrtie, not supported type: ", type(iter_obj))
stree.try_append_python_node(node, node)
return
g_for_parser = reg_parser(ForParser())

View File

@ -44,9 +44,6 @@ class FunctionDefParser(Parser):
else:
parser.process(stree, body)
for body in node.body[::-1]:
if isinstance(body, (ast.For, ast.If)):
node.body.remove(body)
if hasattr(node, "decorator_list"):
stree.try_append_python_node(node, node.decorator_list)
if hasattr(node, "returns"):

View File

@ -46,14 +46,17 @@ class IfParser(Parser):
bodies = None
try:
test_value = eval(test_code)
except NameError:
except (NameError, TypeError):
stree.try_append_python_node(node, node)
return
bodies = node.body if test_value else node.orelse
index = stree.get_ast_root().body.index(node) + 1
info_node = ast.Name(id="# If node has bin replaced by ", lineno=0, col_offset=0, ctx=ast.Load)
exp_node = ast.Expr(value=info_node, lineno=0, col_offset=0, ctx=ast.Load)
stree.get_ast_root().body.insert(index-1, exp_node)
for body in bodies:
stree.get_ast_root().body.insert(index, body)
index += 1
stree.get_ast_root().body.remove(node)
g_if_parser = reg_parser(IfParser())

View File

@ -21,7 +21,6 @@ import ast
import importlib
import types
import time
import astunparse
from mindspore.nn import Cell
@ -1323,15 +1322,32 @@ class SymbolTree(Observer, Observable):
A class handle.
"""
self._update_container()
file_name = "new_network_{0}.py".format(int(time.time() * 10000000))
with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
file_path = os.getcwd()
file_path = os.path.join(file_path, "rewritten_network")
if not os.path.exists(file_path):
os.mkdir(file_path)
file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
network_file = os.path.join(file_path, file_name)
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
source = self.get_code()
f.write(source.encode('utf-8'))
f.flush()
tmp_module_path, tmp_module_file = os.path.split(file_name)
os.fsync(f)
tmp_module_path, tmp_module_file = os.path.split(network_file)
tmp_module_name = tmp_module_file[:-3]
sys.path.append(tmp_module_path)
tmp_module = importlib.import_module(tmp_module_name)
tmp_module = None
i = 0
while not tmp_module:
try:
tmp_module = importlib.import_module(tmp_module_name)
except ModuleNotFoundError:
while i > 10:
break
time.sleep(0.1)
i += 1
if not tmp_module:
logger.error(f"load module {tmp_module_name} failed.")
network_cls = getattr(tmp_module, self._opt_cls_name)
if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name)

View File

@ -17,7 +17,7 @@ import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, Parameter
from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent
from mindspore.common import dtype as mstype
@ -42,7 +42,7 @@ def test_apply_proximal_gradient_descent_float32():
Expectation: success
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
var = Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float32))
var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float32)))
alpha = Tensor(1.0, mstype.float32)
l1 = Tensor(1.0, mstype.float32)
l2 = Tensor(0.0, mstype.float32)
@ -64,7 +64,7 @@ def test_apply_proximal_gradient_descent_float64():
Expectation: success
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
var = Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float64))
var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float64)))
alpha = Tensor(1.0, mstype.float64)
l1 = Tensor(1.0, mstype.float64)
l2 = Tensor(0.0, mstype.float64)