forked from mindspore-Ecosystem/mindspore
!26962 [Fallback] Process the interpret node in the return node
Merge pull request !26962 from huangbingjian/fallback_return_tuple
This commit is contained in:
commit
c621547d26
|
@ -509,14 +509,46 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
|
|||
false_block->Mature();
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::HandelReturnExprNode(const FunctionBlockPtr &block, const AnfNodePtr &return_expr_node,
|
||||
const py::object &value_object) {
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (!use_fallback) {
|
||||
return return_expr_node;
|
||||
}
|
||||
|
||||
// Handle the case of returning tuple.
|
||||
py::object obj = python_adapter::GetPyObjAttr(value_object, "elts");
|
||||
if (!py::isinstance<py::none>(obj)) {
|
||||
auto elts = py::cast<py::tuple>(obj);
|
||||
if (!elts.empty()) {
|
||||
auto cnode = return_expr_node->cast<CNodePtr>();
|
||||
// The first input of cnode is MakeTuple.
|
||||
if (cnode->size() != elts.size() + 1) {
|
||||
MS_LOG(EXCEPTION) << "The size of make_tuple's inputs must be equal to " << (elts.size() + 1) << ".";
|
||||
}
|
||||
for (size_t i = 0; i < elts.size(); i++) {
|
||||
auto input = cnode->input(i + 1);
|
||||
if (input->interpret()) {
|
||||
auto interpreted_node = HandleInterpret(block, input, elts[i]);
|
||||
cnode->set_input(i + 1, interpreted_node);
|
||||
}
|
||||
}
|
||||
return cnode;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the case of a single return value.
|
||||
return HandleInterpret(block, return_expr_node, value_object);
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast return";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// Parse the return Statements value.
|
||||
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
|
||||
// Check if need interpreting.
|
||||
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
|
||||
return_expr_node = HandelReturnExprNode(block, return_expr_node, value_object);
|
||||
// Create the `return` CNode.
|
||||
auto func_graph = block->func_graph();
|
||||
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
|
||||
|
|
|
@ -224,6 +224,10 @@ class Parser {
|
|||
// Assign value to subscript
|
||||
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// Interpret the return node.
|
||||
AnfNodePtr HandelReturnExprNode(const FunctionBlockPtr &block, const AnfNodePtr &return_expr_node,
|
||||
const py::object &value_object);
|
||||
|
||||
// Process a bool operation value list
|
||||
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
|
||||
|
||||
|
|
|
@ -378,7 +378,11 @@ def test_np_squeeze():
|
|||
assert np.all(np_squeeze().asnumpy() == np.array([[0, 1], [2, 3]]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_concat():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -402,7 +406,11 @@ def test_np_concat():
|
|||
assert np.all(out_vstack.asnumpy() == np.array([[1, 2], [3, 4], [5, 6], [7, 8]]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_split():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -423,7 +431,11 @@ def test_np_split():
|
|||
assert np.all(out_vsplit.asnumpy() == np.array([[[0, 1]], [[2, 3]]]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_element():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -447,7 +459,11 @@ def test_np_element():
|
|||
assert np.all(out_unique.asnumpy() == np.array([2, 5, 6, 7, 8, 9]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_bitwise():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -471,7 +487,11 @@ def test_np_bitwise():
|
|||
assert right_shift.asnumpy() == 10
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_char_1():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -500,7 +520,11 @@ def test_np_char_1():
|
|||
assert char_upper.asnumpy() == 'FALLBACK'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_char_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -529,7 +553,11 @@ def test_np_char_2():
|
|||
assert char_decode.asnumpy() == 'runoob'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_degree():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -556,7 +584,11 @@ def test_np_degree():
|
|||
assert np.isclose(out_arctan.asnumpy(), 45)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_math_1():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -585,7 +617,11 @@ def test_np_math_1():
|
|||
assert np.all(out_remainder.asnumpy() == np.array([0, 2]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_math_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -610,7 +646,11 @@ def test_np_math_2():
|
|||
assert np.allclose(out_power.asnumpy(), np.array([1, 4, 9]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_statistic():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -644,7 +684,11 @@ def test_np_statistic():
|
|||
assert np.isclose(out_var.asnumpy(), 2.0)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_sort():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -660,22 +704,41 @@ def test_np_sort():
|
|||
out_argmin = np.argmin(x)
|
||||
out_nonzero = np.nonzero(x)
|
||||
out_where = np.where(x > 4)
|
||||
condition = x % 2 == 0
|
||||
out_extract = np.extract(condition, x)
|
||||
return Tensor(out_sort), Tensor(out_argsort), Tensor(out_argmax), \
|
||||
Tensor(out_argmin), Tensor(out_nonzero), Tensor(out_where), Tensor(out_extract)
|
||||
Tensor(out_argmin), Tensor(out_nonzero), Tensor(out_where)
|
||||
|
||||
out_sort, out_argsort, out_argmax, out_argmin, out_nonzero, out_where, out_extract = np_sort()
|
||||
out_sort, out_argsort, out_argmax, out_argmin, out_nonzero, out_where = np_sort()
|
||||
assert np.all(out_sort.asnumpy() == np.array([1, 2, 3, 4, 5]))
|
||||
assert np.all(out_argsort.asnumpy() == np.array([1, 2, 0, 3, 4]))
|
||||
assert out_argmax.asnumpy() == 4
|
||||
assert out_argmin.asnumpy() == 1
|
||||
assert np.all(out_nonzero.asnumpy() == np.array([0, 1, 2, 3, 4]))
|
||||
assert np.all(out_where.asnumpy() == np.array([4]))
|
||||
assert np.all(out_extract.asnumpy() == np.array([2, 4]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_np_extract():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test numpy extract method in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def np_extract():
|
||||
x = np.array([3, 1, 2, 4, 5])
|
||||
condition = x % 2 == 0
|
||||
out_extract = np.extract(condition, x)
|
||||
return Tensor(out_extract)
|
||||
|
||||
out_extract = np_extract()
|
||||
assert np.all(out_extract.asnumpy() == np.array([2, 4]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_np_matrix():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
Loading…
Reference in New Issue