!51324 Enable return list with PyExecute node.

Merge pull request !51324 from LiangZhibo/return_list
This commit is contained in:
i-robot 2023-03-25 03:44:43 +00:00 committed by Gitee
commit 54fc5b67ea
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 34 additions and 5 deletions

View File

@ -333,9 +333,15 @@ py::object GetVectorRefOutputDataWithPyExecuteObject(const AnfNodePtr &node, con
}
size_t seq_size = abs_seq->size();
// List output will be convert to PyExecute real_node, only need to consider tuple here.
py::tuple ret = py::tuple(seq_size);
auto real_cnode = real_node->cast<CNodePtr>();
if (abs->isa<abstract::AbstractList>()) {
py::list ret = py::list(seq_size);
for (size_t i = 0; i < seq_size; ++i) {
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
}
return ret;
}
py::tuple ret = py::tuple(seq_size);
for (size_t i = 0; i < seq_size; ++i) {
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
}

View File

@ -105,7 +105,6 @@ def test_return_constant_list_4():
assert np.all(res[2].asnumpy() == np.array([2, 3]))
@pytest.mark.skip(reason="No support yet.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -148,7 +147,6 @@ def test_return_constant_list_6():
assert res[2] == 1
@pytest.mark.skip(reason="No support yet.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -227,6 +225,32 @@ def test_return_make_list_node_3():
assert res == [Tensor([1]), 1, "a"]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_node_4():
"""
Feature: Return list in graph
Description: Support return make list node.
Expectation: No exception.
"""
@jit
def foo(x):
x1 = list(x)
x2 = {"a": Tensor(5)}
x3 = (0, 1.0)
return [x1, x2, x3]
res = foo(Tensor([1, 2, 3]))
assert isinstance(res, list)
assert len(res) == 3
assert res[0] == [Tensor([1]), Tensor([2]), Tensor([3])]
assert res[1] == {"a": Tensor(5)}
assert res[2] == (0, 1.0)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -322,7 +346,6 @@ def test_return_make_list_with_nest_2():
assert res == ([Tensor([0]), ([Tensor([0]), 1],)], (Tensor([1]), Tensor([2])))
@pytest.mark.skip(reason="No support yet.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training