forked from mindspore-Ecosystem/mindspore
!48071 fix rewrite if
Merge pull request !48071 from 于振华/fix_rewrite_if_0119
This commit is contained in:
commit
40612d7235
|
@ -15,7 +15,6 @@
|
|||
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
||||
from typing import Union
|
||||
import ast
|
||||
import os
|
||||
import astunparse
|
||||
|
||||
from mindspore import log as logger
|
||||
|
@ -537,12 +536,9 @@ class AssignParser(Parser):
|
|||
f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, "
|
||||
f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
|
||||
child_node=value, father_node=node))
|
||||
except RuntimeError as e:
|
||||
if os.getenv("STREE_PYTHON_FALLBACK"):
|
||||
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
|
||||
stree.try_append_python_node(node, node)
|
||||
else:
|
||||
raise e
|
||||
except RuntimeError:
|
||||
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
|
||||
stree.try_append_python_node(node, node)
|
||||
|
||||
|
||||
g_assign_parser = reg_parser(AssignParser())
|
||||
|
|
|
@ -67,11 +67,12 @@ class ForParser(Parser):
|
|||
iter_code = iter_code.replace("self", "stree.get_origin_network()")
|
||||
try:
|
||||
iter_obj = eval(iter_code)
|
||||
except Exception as e:
|
||||
error_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
|
||||
except (NameError, TypeError) as e:
|
||||
_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
|
||||
f"an error occurred: {str(e)}"
|
||||
logger.error(error_info)
|
||||
raise e
|
||||
logger.warning(_info)
|
||||
stree.try_append_python_node(node, node)
|
||||
return
|
||||
|
||||
iter_var_name = iter_code.split(".")[-1]
|
||||
index = stree.get_ast_root().body.index(node) + 1
|
||||
|
@ -85,14 +86,17 @@ class ForParser(Parser):
|
|||
index += 1
|
||||
if stree.get_ori_cls_name() == "SequentialCell":
|
||||
stree.on_change(Event.CodeChangeEvent)
|
||||
elif isinstance(iter_obj, range):
|
||||
raise NotImplementedError("For MindSpore Rewrtie, range not support")
|
||||
stree.get_ast_root().body.remove(node)
|
||||
return
|
||||
if isinstance(iter_obj, range):
|
||||
logger.warning("For MindSpore Rewrtie, range not support.")
|
||||
elif isinstance(iter_obj, zip):
|
||||
raise NotImplementedError("For MindSpore Rewrtie, zip not support")
|
||||
logger.warning("For MindSpore Rewrtie, zip not support.")
|
||||
elif isinstance(iter_obj, enumerate):
|
||||
raise NotImplementedError("For MindSpore Rewrtie, enumerate not support")
|
||||
logger.warning("For MindSpore Rewrtie, enumerate not support.")
|
||||
else:
|
||||
raise ValueError("For MindSpore Rewrtie, not supported type: ", iter_obj)
|
||||
|
||||
logger.warning("For MindSpore Rewrtie, not supported type: ", type(iter_obj))
|
||||
stree.try_append_python_node(node, node)
|
||||
return
|
||||
|
||||
g_for_parser = reg_parser(ForParser())
|
||||
|
|
|
@ -44,9 +44,6 @@ class FunctionDefParser(Parser):
|
|||
else:
|
||||
parser.process(stree, body)
|
||||
|
||||
for body in node.body[::-1]:
|
||||
if isinstance(body, (ast.For, ast.If)):
|
||||
node.body.remove(body)
|
||||
if hasattr(node, "decorator_list"):
|
||||
stree.try_append_python_node(node, node.decorator_list)
|
||||
if hasattr(node, "returns"):
|
||||
|
|
|
@ -46,14 +46,17 @@ class IfParser(Parser):
|
|||
bodies = None
|
||||
try:
|
||||
test_value = eval(test_code)
|
||||
except NameError:
|
||||
except (NameError, TypeError):
|
||||
stree.try_append_python_node(node, node)
|
||||
return
|
||||
|
||||
bodies = node.body if test_value else node.orelse
|
||||
index = stree.get_ast_root().body.index(node) + 1
|
||||
info_node = ast.Name(id="# If node has bin replaced by ", lineno=0, col_offset=0, ctx=ast.Load)
|
||||
exp_node = ast.Expr(value=info_node, lineno=0, col_offset=0, ctx=ast.Load)
|
||||
stree.get_ast_root().body.insert(index-1, exp_node)
|
||||
for body in bodies:
|
||||
stree.get_ast_root().body.insert(index, body)
|
||||
index += 1
|
||||
|
||||
stree.get_ast_root().body.remove(node)
|
||||
g_if_parser = reg_parser(IfParser())
|
||||
|
|
|
@ -21,7 +21,6 @@ import ast
|
|||
import importlib
|
||||
import types
|
||||
import time
|
||||
|
||||
import astunparse
|
||||
|
||||
from mindspore.nn import Cell
|
||||
|
@ -1323,15 +1322,32 @@ class SymbolTree(Observer, Observable):
|
|||
A class handle.
|
||||
"""
|
||||
self._update_container()
|
||||
file_name = "new_network_{0}.py".format(int(time.time() * 10000000))
|
||||
with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
||||
file_path = os.getcwd()
|
||||
file_path = os.path.join(file_path, "rewritten_network")
|
||||
if not os.path.exists(file_path):
|
||||
os.mkdir(file_path)
|
||||
file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
|
||||
network_file = os.path.join(file_path, file_name)
|
||||
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
||||
source = self.get_code()
|
||||
f.write(source.encode('utf-8'))
|
||||
f.flush()
|
||||
tmp_module_path, tmp_module_file = os.path.split(file_name)
|
||||
os.fsync(f)
|
||||
tmp_module_path, tmp_module_file = os.path.split(network_file)
|
||||
tmp_module_name = tmp_module_file[:-3]
|
||||
sys.path.append(tmp_module_path)
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
tmp_module = None
|
||||
i = 0
|
||||
while not tmp_module:
|
||||
try:
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
except ModuleNotFoundError:
|
||||
while i > 10:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
if not tmp_module:
|
||||
logger.error(f"load module {tmp_module_name} failed.")
|
||||
network_cls = getattr(tmp_module, self._opt_cls_name)
|
||||
if network_cls is None:
|
||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyProximalGradientDescent
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
@ -42,7 +42,7 @@ def test_apply_proximal_gradient_descent_float32():
|
|||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
var = Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float32))
|
||||
var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float32)))
|
||||
alpha = Tensor(1.0, mstype.float32)
|
||||
l1 = Tensor(1.0, mstype.float32)
|
||||
l2 = Tensor(0.0, mstype.float32)
|
||||
|
@ -64,7 +64,7 @@ def test_apply_proximal_gradient_descent_float64():
|
|||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
var = Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float64))
|
||||
var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]]).astype(np.float64)))
|
||||
alpha = Tensor(1.0, mstype.float64)
|
||||
l1 = Tensor(1.0, mstype.float64)
|
||||
l2 = Tensor(0.0, mstype.float64)
|
||||
|
|
Loading…
Reference in New Issue