Add the check of function return None.
This commit is contained in:
parent
66b9de8e70
commit
f9a384456a
|
@ -480,6 +480,16 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
|
|||
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
|
||||
// Create the `return` CNode.
|
||||
auto func_graph = block->func_graph();
|
||||
if (IsValueNode<None>(return_expr_node)) {
|
||||
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
const auto min_list_size = 2;
|
||||
if (ret.size() < min_list_size) {
|
||||
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
|
||||
}
|
||||
py::str desc =
|
||||
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Function should not 'Return None', is located in:" << desc.cast<std::string>();
|
||||
}
|
||||
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
|
||||
func_graph->set_return(return_cnode);
|
||||
return block;
|
||||
|
|
|
@ -146,9 +146,6 @@ TEST_F(TestParser, TestParseGraphNamedConst) {
|
|||
GetPythonFunction("testDoNamedConstFalse");
|
||||
ret_val = ParsePythonCode(fn);
|
||||
ASSERT_TRUE(nullptr != ret_val);
|
||||
GetPythonFunction("testDoNamedConstNone");
|
||||
ret_val = ParsePythonCode(fn);
|
||||
ASSERT_TRUE(nullptr != ret_val);
|
||||
}
|
||||
|
||||
TEST_F(TestParser, TestParseGraphForStatement) {
|
||||
|
|
|
@ -221,10 +221,6 @@ def testDoNamedConstFalse():
|
|||
return False
|
||||
|
||||
|
||||
def testDoNamedConstNone():
|
||||
return None
|
||||
|
||||
|
||||
# Test_Class_type
|
||||
@dataclass
|
||||
class TestFoo:
|
||||
|
|
|
@ -61,9 +61,10 @@ class NetMissConstruct(nn.Cell):
|
|||
|
||||
def test_net_without_construct():
|
||||
""" test_net_without_construct """
|
||||
net = NetMissConstruct()
|
||||
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
||||
_cell_graph_executor.compile(net, inp)
|
||||
with pytest.raises(TypeError, match="Function should not 'Return None'"):
|
||||
net = NetMissConstruct()
|
||||
inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
||||
_cell_graph_executor.compile(net, inp)
|
||||
|
||||
|
||||
class NetWithRaise(nn.Cell):
|
||||
|
|
|
@ -192,7 +192,8 @@ def test_missing_construct():
|
|||
def construct1(self, inp):
|
||||
return 5
|
||||
|
||||
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
|
||||
tensor = Tensor(np_input)
|
||||
net = NetMissConstruct()
|
||||
assert net(tensor) is None
|
||||
with pytest.raises(TypeError, match="Function should not 'Return None'"):
|
||||
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
|
||||
tensor = Tensor(np_input)
|
||||
net = NetMissConstruct()
|
||||
assert net(tensor) is None
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test super"""
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
@ -104,10 +104,11 @@ def test_mul_super():
|
|||
|
||||
|
||||
def test_super_cell():
|
||||
net = Net(2)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
assert net(x, y) is None
|
||||
with pytest.raises(TypeError, match="Function should not 'Return None'"):
|
||||
net = Net(2)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
assert net(x, y) is None
|
||||
|
||||
|
||||
def test_single_super_in():
|
||||
|
|
Loading…
Reference in New Issue