Process the interpret node in the return node

This commit is contained in:
huangbingjian 2021-11-29 17:07:26 +08:00
parent 8bdcc68bb7
commit e970c96d68
3 changed files with 117 additions and 18 deletions

View File

@ -509,14 +509,46 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
false_block->Mature();
}
AnfNodePtr Parser::HandelReturnExprNode(const FunctionBlockPtr &block, const AnfNodePtr &return_expr_node,
const py::object &value_object) {
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (!use_fallback) {
return return_expr_node;
}
// Handle the case of returning tuple.
py::object obj = python_adapter::GetPyObjAttr(value_object, "elts");
if (!py::isinstance<py::none>(obj)) {
auto elts = py::cast<py::tuple>(obj);
if (!elts.empty()) {
auto cnode = return_expr_node->cast<CNodePtr>();
// The first input of cnode is MakeTuple.
if (cnode->size() != elts.size() + 1) {
MS_LOG(EXCEPTION) << "The size of make_tuple's inputs must be equal to " << (elts.size() + 1) << ".";
}
for (size_t i = 0; i < elts.size(); i++) {
auto input = cnode->input(i + 1);
if (input->interpret()) {
auto interpreted_node = HandleInterpret(block, input, elts[i]);
cnode->set_input(i + 1, interpreted_node);
}
}
return cnode;
}
}
// Handle the case of a single return value.
return HandleInterpret(block, return_expr_node, value_object);
}
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast return";
MS_EXCEPTION_IF_NULL(block);
// Parse the return Statements value.
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
// Check if need interpreting.
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
return_expr_node = HandelReturnExprNode(block, return_expr_node, value_object);
// Create the `return` CNode.
auto func_graph = block->func_graph();
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});

View File

@ -224,6 +224,10 @@ class Parser {
// Assign value to subscript
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// Interpret the return node.
AnfNodePtr HandelReturnExprNode(const FunctionBlockPtr &block, const AnfNodePtr &return_expr_node,
const py::object &value_object);
// Process a bool operation value list
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);

View File

@ -378,7 +378,11 @@ def test_np_squeeze():
assert np.all(np_squeeze().asnumpy() == np.array([[0, 1], [2, 3]]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_concat():
"""
Feature: JIT Fallback
@ -402,7 +406,11 @@ def test_np_concat():
assert np.all(out_vstack.asnumpy() == np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_split():
"""
Feature: JIT Fallback
@ -423,7 +431,11 @@ def test_np_split():
assert np.all(out_vsplit.asnumpy() == np.array([[[0, 1]], [[2, 3]]]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_element():
"""
Feature: JIT Fallback
@ -447,7 +459,11 @@ def test_np_element():
assert np.all(out_unique.asnumpy() == np.array([2, 5, 6, 7, 8, 9]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_bitwise():
"""
Feature: JIT Fallback
@ -471,7 +487,11 @@ def test_np_bitwise():
assert right_shift.asnumpy() == 10
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_char_1():
"""
Feature: JIT Fallback
@ -500,7 +520,11 @@ def test_np_char_1():
assert char_upper.asnumpy() == 'FALLBACK'
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_char_2():
"""
Feature: JIT Fallback
@ -529,7 +553,11 @@ def test_np_char_2():
assert char_decode.asnumpy() == 'runoob'
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_degree():
"""
Feature: JIT Fallback
@ -556,7 +584,11 @@ def test_np_degree():
assert np.isclose(out_arctan.asnumpy(), 45)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_math_1():
"""
Feature: JIT Fallback
@ -585,7 +617,11 @@ def test_np_math_1():
assert np.all(out_remainder.asnumpy() == np.array([0, 2]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_math_2():
"""
Feature: JIT Fallback
@ -610,7 +646,11 @@ def test_np_math_2():
assert np.allclose(out_power.asnumpy(), np.array([1, 4, 9]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_statistic():
"""
Feature: JIT Fallback
@ -644,7 +684,11 @@ def test_np_statistic():
assert np.isclose(out_var.asnumpy(), 2.0)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_sort():
"""
Feature: JIT Fallback
@ -660,22 +704,41 @@ def test_np_sort():
out_argmin = np.argmin(x)
out_nonzero = np.nonzero(x)
out_where = np.where(x > 4)
condition = x % 2 == 0
out_extract = np.extract(condition, x)
return Tensor(out_sort), Tensor(out_argsort), Tensor(out_argmax), \
Tensor(out_argmin), Tensor(out_nonzero), Tensor(out_where), Tensor(out_extract)
Tensor(out_argmin), Tensor(out_nonzero), Tensor(out_where)
out_sort, out_argsort, out_argmax, out_argmin, out_nonzero, out_where, out_extract = np_sort()
out_sort, out_argsort, out_argmax, out_argmin, out_nonzero, out_where = np_sort()
assert np.all(out_sort.asnumpy() == np.array([1, 2, 3, 4, 5]))
assert np.all(out_argsort.asnumpy() == np.array([1, 2, 0, 3, 4]))
assert out_argmax.asnumpy() == 4
assert out_argmin.asnumpy() == 1
assert np.all(out_nonzero.asnumpy() == np.array([0, 1, 2, 3, 4]))
assert np.all(out_where.asnumpy() == np.array([4]))
assert np.all(out_extract.asnumpy() == np.array([2, 4]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_extract():
"""
Feature: JIT Fallback
Description: Test numpy extract method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_extract():
x = np.array([3, 1, 2, 4, 5])
condition = x % 2 == 0
out_extract = np.extract(condition, x)
return Tensor(out_extract)
out_extract = np_extract()
assert np.all(out_extract.asnumpy() == np.array([2, 4]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_matrix():
"""
Feature: JIT Fallback