forked from mindspore-Ecosystem/mindspore
!13142 support name or attribute ast.Expr
From: @zhangbuxue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
1e33df94a5
|
@ -21,7 +21,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
|
|||
get_class_member_namespace_symbol, create_slice_obj,
|
||||
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
get_parse_method_of_class, get_scope_name, expand_expr_statement,
|
||||
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
|
||||
from .serialize import *
|
||||
|
||||
|
@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
|
||||
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
|
||||
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
|
||||
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description']
|
||||
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
|
||||
|
|
|
@ -347,6 +347,30 @@ def get_object_description(obj, fname, fline):
|
|||
return str(obj)
|
||||
|
||||
|
||||
def expand_expr_statement(node):
|
||||
"""
|
||||
Process the expr statement and expand it.
|
||||
|
||||
Returns:
|
||||
tuple, (True, expr.value, x)/(False, None, None).
|
||||
"""
|
||||
if isinstance(node, ast.Expr):
|
||||
expr_value = node.value
|
||||
if isinstance(expr_value, ast.Call):
|
||||
func = expr_value.func
|
||||
if isinstance(func, ast.Attribute) and \
|
||||
hasattr(func, "attr") and \
|
||||
hasattr(func, "value"):
|
||||
method = func.attr
|
||||
target = func.value
|
||||
if method in parse_expr_statement_white_list:
|
||||
logger.debug("Expand expr, target:%s, method:%s", target, method)
|
||||
return True, expr_value, target
|
||||
if not isinstance(expr_value, ast.Str):
|
||||
return True, expr_value
|
||||
return (False,)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
@ -548,25 +572,3 @@ class Parser:
|
|||
else:
|
||||
ret = ret + [0, 0, 0, 0]
|
||||
return ret
|
||||
|
||||
def expand_expr_statement(self, node):
|
||||
"""
|
||||
Process the expr statement and expand it.
|
||||
|
||||
Returns:
|
||||
tuple, (True, expr.value, x)/(False, None, None).
|
||||
"""
|
||||
if isinstance(node, ast.Expr) and hasattr(node, "value"):
|
||||
expr_value = node.value
|
||||
if isinstance(expr_value, ast.Call):
|
||||
func = expr_value.func
|
||||
if isinstance(func, ast.Attribute) and \
|
||||
hasattr(func, "attr") and \
|
||||
hasattr(func, "value"):
|
||||
method = func.attr
|
||||
target = func.value
|
||||
if method in parse_expr_statement_white_list:
|
||||
logger.debug("Expand expr, target:%s, method:%s", target, method)
|
||||
return True, expr_value, target
|
||||
return True, expr_value
|
||||
return False, None, None
|
||||
|
|
|
@ -418,7 +418,7 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
|
|||
FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Expr";
|
||||
// Expr only have value, no target
|
||||
py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
|
||||
py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
|
||||
|
||||
// Refer python function expand_expr_statement, expand_info is one of the following:
|
||||
// True, expr.value, x
|
||||
|
|
|
@ -60,8 +60,7 @@ bool SymbolResolver::Resolve() {
|
|||
py::object obj = namespace_->obj();
|
||||
std::string symbol = symbol_->symbol();
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(ERROR) << "Unresolved symbol: " << symbol;
|
||||
return false;
|
||||
MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
|
||||
}
|
||||
result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol));
|
||||
return true;
|
||||
|
@ -294,10 +293,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr
|
|||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr";
|
||||
}
|
||||
SymbolResolver symbol_resolver(name_space, symbol, node);
|
||||
if (!symbol_resolver.Resolve()) {
|
||||
MS_EXCEPTION(TypeError) << "Parse Resolve node failed NodeInfo.";
|
||||
}
|
||||
|
||||
symbol_resolver.Resolve();
|
||||
py::object obj = symbol_resolver.result();
|
||||
AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
|
|
|
@ -629,6 +629,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
|
|||
} catch (const py::attribute_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::attribute_error(ex);
|
||||
} catch (const py::name_error &ex) {
|
||||
ReleaseResource(phase);
|
||||
throw py::name_error(ex);
|
||||
} catch (const std::exception &ex) {
|
||||
ReleaseResource(phase);
|
||||
// re-throw this exception to Python interpreter to handle it
|
||||
|
|
|
@ -45,6 +45,9 @@ class PyExceptionInitializer {
|
|||
if (exception_type == AttributeError) {
|
||||
throw py::attribute_error(str);
|
||||
}
|
||||
if (exception_type == NameError) {
|
||||
throw py::name_error(str);
|
||||
}
|
||||
py::pybind11_fail(str);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
namespace pybind11 {
|
||||
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
|
||||
}
|
||||
PYBIND11_RUNTIME_EXCEPTION(name_error, PyExc_NameError)
|
||||
} // namespace pybind11
|
||||
|
||||
#endif // PYBIND_API_PYBIND_PATCH_H_
|
||||
|
|
|
@ -95,36 +95,6 @@ static int GetSlogLevel(MsLogLevel level) {
|
|||
}
|
||||
#endif
|
||||
|
||||
static std::string ExceptionTypeToString(ExceptionType type) {
|
||||
#define _TO_STRING(x) #x
|
||||
// clang-format off
|
||||
static const char *const type_names[] = {
|
||||
_TO_STRING(NoExceptionType),
|
||||
_TO_STRING(UnknownError),
|
||||
_TO_STRING(ArgumentError),
|
||||
_TO_STRING(NotSupportError),
|
||||
_TO_STRING(NotExistsError),
|
||||
_TO_STRING(AlreadyExistsError),
|
||||
_TO_STRING(UnavailableError),
|
||||
_TO_STRING(DeviceProcessError),
|
||||
_TO_STRING(AbortedError),
|
||||
_TO_STRING(TimeOutError),
|
||||
_TO_STRING(ResourceUnavailable),
|
||||
_TO_STRING(NoPermissionError),
|
||||
_TO_STRING(IndexError),
|
||||
_TO_STRING(ValueError),
|
||||
_TO_STRING(TypeError),
|
||||
_TO_STRING(KeyError),
|
||||
_TO_STRING(AttributeError),
|
||||
};
|
||||
// clang-format on
|
||||
#undef _TO_STRING
|
||||
if (type < UnknownError || type > AttributeError) {
|
||||
type = UnknownError;
|
||||
}
|
||||
return std::string(type_names[type]);
|
||||
}
|
||||
|
||||
static const char *GetSubModuleName(SubModuleId module_id) {
|
||||
static const char *sub_module_names[NUM_SUBMODUES] = {
|
||||
"UNKNOWN", // SM_UNKNOWN
|
||||
|
@ -185,10 +155,6 @@ void LogWriter::operator^(const LogStream &stream) const {
|
|||
|
||||
std::ostringstream oss;
|
||||
oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] ";
|
||||
if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError &&
|
||||
exception_type_ != ValueError && exception_type_ != KeyError && exception_type_ != AttributeError) {
|
||||
oss << ExceptionTypeToString(exception_type_) << " ";
|
||||
}
|
||||
oss << msg.str();
|
||||
|
||||
if (trace_provider_ != nullptr) {
|
||||
|
|
|
@ -60,6 +60,7 @@ enum ExceptionType {
|
|||
TypeError,
|
||||
KeyError,
|
||||
AttributeError,
|
||||
NameError
|
||||
};
|
||||
|
||||
struct LocationInfo {
|
||||
|
|
|
@ -162,7 +162,7 @@ def test_sequential_resolve_error():
|
|||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
net = SequenceNet()
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(NameError):
|
||||
net(input_me)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test use undefined var"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_use_undefined_var():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = [11, 22, 33, 44]
|
||||
|
||||
def construct(self, x):
|
||||
ret = x + c
|
||||
return ret
|
||||
net = Net()
|
||||
with pytest.raises(NameError) as err:
|
||||
net(Tensor(np.arange(4)))
|
||||
assert "The name 'c' is not defined" in str(err.value)
|
||||
|
||||
|
||||
def test_insert_undefined_var():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = [11, 22, 33, 44]
|
||||
|
||||
def construct(self, x):
|
||||
c
|
||||
ret = x + x
|
||||
return ret
|
||||
net = Net()
|
||||
with pytest.raises(NameError) as err:
|
||||
net(Tensor(np.arange(4)))
|
||||
assert "The name 'c' is not defined" in str(err.value)
|
||||
|
||||
|
||||
def test_insert_undefined_var_compute():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = [11, 22, 33, 44]
|
||||
|
||||
def construct(self, x):
|
||||
c + d
|
||||
ret = x + x
|
||||
return ret
|
||||
net = Net()
|
||||
with pytest.raises(NameError) as err:
|
||||
net(Tensor(np.arange(4)))
|
||||
assert "The name 'c' is not defined" in str(err.value)
|
||||
|
||||
|
||||
def test_insert_defined_var():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = [11, 22, 33, 44]
|
||||
|
||||
def construct(self, x):
|
||||
x
|
||||
ret = x + x
|
||||
return ret
|
||||
net = Net()
|
||||
net(Tensor(np.arange(4)))
|
||||
|
||||
|
||||
def test_insert_defined_var_compute():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = [11, 22, 33, 44]
|
||||
|
||||
def construct(self, x):
|
||||
x - x
|
||||
ret = x + x
|
||||
return ret
|
||||
net = Net()
|
||||
net(Tensor(np.arange(4)))
|
Loading…
Reference in New Issue