forked from mindspore-Ecosystem/mindspore
Fix coredump missing return statement after while loop
This commit is contained in:
parent
406ce73515
commit
38083e055a
|
@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
|
|||
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor)
|
||||
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
|
||||
from .serialize import *
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
|
@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
|
||||
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
|
||||
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
|
||||
'create_slice_obj', 'convert_to_ms_tensor']
|
||||
'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description']
|
||||
|
|
|
@ -322,6 +322,20 @@ def convert_to_ms_tensor(data):
|
|||
return MsTensor(data)
|
||||
|
||||
|
||||
def get_object_description(obj, fname, fline):
|
||||
"""return method or funcition description for error report, include location, class name, etc."""
|
||||
if isinstance(obj, types.MethodType):
|
||||
obj_cls = obj.__self__.__class__
|
||||
class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}'
|
||||
cls_fname = inspect.getfile(obj_cls)
|
||||
_, cls_fline = inspect.getsourcelines(obj_cls)
|
||||
class_loc = f'{cls_fname}:{cls_fline}'
|
||||
return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>"
|
||||
if isinstance(obj, (types.FunctionType, ast.FunctionDef)):
|
||||
return f"function '{obj.name}' at {fname}:{fline}"
|
||||
return str(obj)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
|
|
@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() {
|
|||
RemoveUnnecessaryPhis();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(pFnBlock);
|
||||
|
||||
// check whether the functions refered by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(pFnBlock->func_graph(), false);
|
||||
for (auto func_graph : mng->func_graphs()) {
|
||||
if (func_graph->get_return() != nullptr) {
|
||||
continue;
|
||||
}
|
||||
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
py::str desc =
|
||||
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
}
|
||||
// clear manager info after checking missing return
|
||||
for (auto fg : mng->func_graphs()) {
|
||||
fg->ClearAllManagerInfo();
|
||||
}
|
||||
|
||||
return pFnBlock->func_graph();
|
||||
}
|
||||
|
||||
|
@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
(void)ParseStatements(pFunBlock, funcObj);
|
||||
|
||||
if (current_fg->get_return() == nullptr) {
|
||||
MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString();
|
||||
errcode_ = PARSE_NO_RETURN;
|
||||
return pFunBlock;
|
||||
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
}
|
||||
GenerateArgsDefaultValueForFunction(pFunBlock, node);
|
||||
return pFunBlock;
|
||||
|
@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
|
|||
}
|
||||
auto filename = location[0].cast<std::string>();
|
||||
auto line_no = location[1].cast<int>();
|
||||
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
|
||||
auto fn_loc = block->func_graph()->debug_info()->location();
|
||||
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
|
||||
fn_loc->file_name(), fn_loc->line());
|
||||
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
|
||||
<< desc.cast<std::string>() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
|
|||
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
auto filename = ret[0].cast<std::string>();
|
||||
auto line_no = ret[1].cast<int>();
|
||||
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
|
||||
auto fn_loc = block->func_graph()->debug_info()->location();
|
||||
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
|
||||
fn_loc->file_name(), fn_loc->line());
|
||||
MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in "
|
||||
<< desc.cast<std::string>() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
|||
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
||||
const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class";
|
||||
const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class";
|
||||
const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description";
|
||||
const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor";
|
||||
|
||||
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
|
||||
|
|
|
@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
|
|||
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
|
||||
while (!nodes_ordered.empty()) {
|
||||
AnfNodePtr node = nodes_ordered.pop();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
// Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor
|
||||
MS_LOG(WARNING) << "Node to be dropped is nullptr";
|
||||
continue;
|
||||
}
|
||||
if (!all_nodes_.contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) {
|
|||
ASSERT_TRUE(nullptr != func_graph);
|
||||
}
|
||||
|
||||
TEST_F(TestParser, TestParseGraphFailure) {
|
||||
GetPythonFunction("get_no_return_fn");
|
||||
|
||||
// create parser
|
||||
std::shared_ptr<ParseAst> ast = std::make_shared<ParseAst>(fn);
|
||||
bool succ = ast->InitParseAstInfo();
|
||||
ASSERT_TRUE(succ = true);
|
||||
std::shared_ptr<Parser> parser = std::make_shared<Parser>(ast);
|
||||
|
||||
// parse ast to graph
|
||||
FuncGraphPtr func_graph = parser->ParseFuncGraph();
|
||||
ASSERT_EQ(PARSE_NO_RETURN, parser->errcode());
|
||||
ASSERT_TRUE(nullptr == func_graph);
|
||||
}
|
||||
|
||||
TEST_F(TestParser, TestParseGraphIf) {
|
||||
GetPythonFunction("test_if");
|
||||
|
||||
|
|
|
@ -689,3 +689,26 @@ def test_while_concat():
|
|||
x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
|
||||
net = Net(x)
|
||||
net(x)
|
||||
|
||||
|
||||
def test_tensor_all_construct_lack_branch():
|
||||
class NetConditionLackBranch(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetConditionLackBranch, self).__init__()
|
||||
self.logicaland = P.LogicalAnd()
|
||||
self.logicalor = P.LogicalOr()
|
||||
|
||||
def construct(self, input1, input2):
|
||||
if input1.all():
|
||||
return self.logicaland(input1, input2)
|
||||
while input1.any():
|
||||
return self.logicalor(input1, input2)
|
||||
# NOTICE: here missing return statement, default return None
|
||||
|
||||
input_np_1 = np.random.choice([True], size=(2, 3, 4, 5))
|
||||
input_tensor_1 = Tensor(input_np_1)
|
||||
input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5))
|
||||
input_tensor_2 = Tensor(input_np_2)
|
||||
net = NetConditionLackBranch()
|
||||
with pytest.raises(Exception):
|
||||
net(input_tensor_1, input_tensor_2)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import functools
|
||||
import logging
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
@ -62,13 +63,9 @@ def test_net_without_construct():
|
|||
""" test_net_without_construct """
|
||||
net = NetMissConstruct()
|
||||
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
||||
try:
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
_executor.compile(net, inp)
|
||||
except RuntimeError as err:
|
||||
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
|
||||
print(str(err))
|
||||
else:
|
||||
raise err
|
||||
assert "Unsupported syntax 'Raise' at " in str(err.value)
|
||||
|
||||
|
||||
class NetWithRaise(nn.Cell):
|
||||
|
@ -87,13 +84,9 @@ def test_net_with_raise():
|
|||
""" test_net_with_raise """
|
||||
net = NetWithRaise()
|
||||
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
||||
try:
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
_executor.compile(net, inp)
|
||||
except RuntimeError as err:
|
||||
if str(err).find("Unsupported syntax 'Raise' at ") >= 0:
|
||||
print(str(err))
|
||||
else:
|
||||
raise err
|
||||
assert "Unsupported syntax 'Raise' at " in str(err.value)
|
||||
|
||||
|
||||
class NetAddN(nn.Cell):
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test mindspore grammar constraints
|
||||
1. funtion must have return statement
|
||||
2. raise statement can not be used
|
||||
"""
|
||||
# pylint: disable=R1705, R1710, W0223
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
def test_missing_return():
|
||||
class NetMissReturn(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetMissReturn, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return 10
|
||||
elif x == 20:
|
||||
if y == 1:
|
||||
return 3
|
||||
elif y == 2:
|
||||
for i in range(z):
|
||||
return i + z
|
||||
i = 0
|
||||
while i < z:
|
||||
return i + z
|
||||
def g(u):
|
||||
return x + u
|
||||
# here method 'construct' misses a return statement
|
||||
g(y)
|
||||
else:
|
||||
return 7
|
||||
else:
|
||||
return 5
|
||||
|
||||
net = NetMissReturn()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
with pytest.raises(TypeError) as er:
|
||||
net(x, y, z)
|
||||
assert "Missing return statement in bound method 'construct'" in str(er.value)
|
||||
|
||||
|
||||
def test_nest_function_missing_return():
|
||||
class NetNestFuncMissReturn(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetNestFuncMissReturn, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return 10
|
||||
elif x == 20:
|
||||
if y == 1:
|
||||
return 3
|
||||
elif y == 2:
|
||||
for i in range(z):
|
||||
return i + z
|
||||
i = 0
|
||||
while i < z:
|
||||
return i + z
|
||||
def g(u):
|
||||
x += u
|
||||
# nested function 'g' misses a return a statement
|
||||
return g(y)
|
||||
else:
|
||||
return 7
|
||||
else:
|
||||
return 5
|
||||
|
||||
net = NetNestFuncMissReturn()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
with pytest.raises(TypeError) as er:
|
||||
net(x, y, z)
|
||||
assert "Missing return statement in function 'g'" in str(er.value)
|
||||
|
||||
|
||||
def test_raise_in_method():
|
||||
class NetRaiseInMethod(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetRaiseInMethod, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return 10
|
||||
elif x == 20:
|
||||
# add not support grammar 'raise' here
|
||||
raise ValueError('Illegal case')
|
||||
else:
|
||||
return y + z
|
||||
|
||||
net = NetRaiseInMethod()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(x, y, z)
|
||||
assert "Unsupported syntax 'Raise' at" in str(er.value)
|
||||
|
||||
|
||||
def test_raise_in_nested_function():
|
||||
class NetNestRaise(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetNestRaise, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return 10
|
||||
elif x == 20:
|
||||
def nest_fn(u):
|
||||
if u > 0:
|
||||
# add not support grammar 'raise' here
|
||||
raise ValueError('Illegal case')
|
||||
return u + z + 1
|
||||
return nest_fn(y)
|
||||
else:
|
||||
return y + z
|
||||
|
||||
net = NetNestRaise()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(x, y, z)
|
||||
assert "Unsupported syntax 'Raise' at " in str(er.value)
|
||||
|
||||
|
||||
def test_nest_branch_with_return():
|
||||
class NetBranchWithReturn(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetBranchWithReturn, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return 10
|
||||
else:
|
||||
return 5
|
||||
|
||||
context.set_context(save_graphs=True)
|
||||
net = NetBranchWithReturn()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
net(x, y, z)
|
||||
|
||||
|
||||
def test_any_with_no_return():
|
||||
class NetAnyNoReturn(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetAnyNoReturn, self).__init__()
|
||||
|
||||
def construct(self, inp):
|
||||
result = inp.any()
|
||||
if result:
|
||||
return 6
|
||||
|
||||
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
|
||||
tensor = Tensor(np_input)
|
||||
net = NetAnyNoReturn()
|
||||
with pytest.raises(TypeError) as er:
|
||||
net(tensor)
|
||||
assert "Missing return statement in bound method 'construct'" in str(er.value)
|
||||
|
||||
|
||||
def test_missing_construct():
|
||||
class NetMissConstruct(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetMissConstruct, self).__init__()
|
||||
|
||||
def construct1(self, inp):
|
||||
return 5
|
||||
|
||||
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
|
||||
tensor = Tensor(np_input)
|
||||
net = NetMissConstruct()
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(tensor)
|
||||
assert "Unsupported syntax 'Raise' at " in str(er.value)
|
Loading…
Reference in New Issue