show accurate error line when use user defined class

This commit is contained in:
buxue 2021-04-01 21:29:51 +08:00
parent 6db503ec9d
commit 6f1105ea79
2 changed files with 33 additions and 6 deletions

View File

@ -168,16 +168,19 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
auto bits_str = value.substr(start);
return MakeResolveClassMember(bits_str);
}
py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
if (namespace_var[0].is_none()) {
if (namespace_var.size() >= 3) {
MS_EXCEPTION(NameError) << namespace_var[2].cast<std::string>();
py::tuple namespace_info = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
if (namespace_info[0].is_none()) {
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
if (namespace_info.size() >= 3) {
MS_EXCEPTION(NameError) << namespace_info[2].cast<std::string>();
}
// If the size of namespace_var is less than 3, the default error information is used.
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
return MakeResolve(name_space, symbol);
}
@ -270,6 +273,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
}
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
if (arg_node != nullptr) {
arg_node->set_debug_info(phi->debug_info());
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
<< arg_node->DebugString();
// Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."

View File

@ -372,3 +372,26 @@ def test_call_unsupported_builtin_function_in_if_in_for():
assert "tests/ut/python/pipeline/parse/test_use_undefined_name_or_unsupported_builtin_function.py(364)" in \
str(err.value)
assert "x = divmod(x, i)" in str(err.value)
def test_use_defined_class_obj_in_for():
class Test:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [1, 2, 3]
self.test = Test()
def construct(self, x):
for i in self.value:
x = i + self.test.number
ret = x + x
return ret
net = Net()
with pytest.raises(TypeError) as err:
net(Tensor([1, 2, 3], mstype.float32))
assert "Invalid object with type" in str(err.value)