Add the check of function return None.

This commit is contained in:
Margaret_wangrui 2021-11-24 09:05:47 +08:00
parent 66b9de8e70
commit f9a384456a
6 changed files with 25 additions and 19 deletions

View File

@ -480,6 +480,16 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
// Create the `return` CNode.
auto func_graph = block->func_graph();
if (IsValueNode<None>(return_expr_node)) {
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
const auto min_list_size = 2;
if (ret.size() < min_list_size) {
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
}
py::str desc =
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Function should not 'Return None', is located in:" << desc.cast<std::string>();
}
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
func_graph->set_return(return_cnode);
return block;

View File

@ -146,9 +146,6 @@ TEST_F(TestParser, TestParseGraphNamedConst) {
GetPythonFunction("testDoNamedConstFalse");
ret_val = ParsePythonCode(fn);
ASSERT_TRUE(nullptr != ret_val);
GetPythonFunction("testDoNamedConstNone");
ret_val = ParsePythonCode(fn);
ASSERT_TRUE(nullptr != ret_val);
}
TEST_F(TestParser, TestParseGraphForStatement) {

View File

@ -221,10 +221,6 @@ def testDoNamedConstFalse():
return False
def testDoNamedConstNone():
return None
# Test_Class_type
@dataclass
class TestFoo:

View File

@ -61,9 +61,10 @@ class NetMissConstruct(nn.Cell):
def test_net_without_construct():
""" test_net_without_construct """
net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
_cell_graph_executor.compile(net, inp)
with pytest.raises(TypeError, match="Function should not 'Return None'"):
net = NetMissConstruct()
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
_cell_graph_executor.compile(net, inp)
class NetWithRaise(nn.Cell):

View File

@ -192,7 +192,8 @@ def test_missing_construct():
def construct1(self, inp):
return 5
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
assert net(tensor) is None
with pytest.raises(TypeError, match="Function should not 'Return None'"):
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
assert net(tensor) is None

View File

@ -14,7 +14,7 @@
# ============================================================================
""" test super"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
@ -104,10 +104,11 @@ def test_mul_super():
def test_super_cell():
net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
assert net(x, y) is None
with pytest.raises(TypeError, match="Function should not 'Return None'"):
net = Net(2)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
assert net(x, y) is None
def test_single_super_in():