!51324 Enable return list with PyExecute node.
Merge pull request !51324 from LiangZhibo/return_list
This commit is contained in:
commit
54fc5b67ea
|
@ -333,9 +333,15 @@ py::object GetVectorRefOutputDataWithPyExecuteObject(const AnfNodePtr &node, con
|
|||
}
|
||||
|
||||
size_t seq_size = abs_seq->size();
|
||||
// List output will be convert to PyExecute real_node, only need to consider tuple here.
|
||||
py::tuple ret = py::tuple(seq_size);
|
||||
auto real_cnode = real_node->cast<CNodePtr>();
|
||||
if (abs->isa<abstract::AbstractList>()) {
|
||||
py::list ret = py::list(seq_size);
|
||||
for (size_t i = 0; i < seq_size; ++i) {
|
||||
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
py::tuple ret = py::tuple(seq_size);
|
||||
for (size_t i = 0; i < seq_size; ++i) {
|
||||
ret[i] = GetVectorRefOutputDataWithPyExecuteObject(real_cnode->input(i + 1), value_seq[i]);
|
||||
}
|
||||
|
|
|
@ -105,7 +105,6 @@ def test_return_constant_list_4():
|
|||
assert np.all(res[2].asnumpy() == np.array([2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support yet.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -148,7 +147,6 @@ def test_return_constant_list_6():
|
|||
assert res[2] == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support yet.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -227,6 +225,32 @@ def test_return_make_list_node_3():
|
|||
assert res == [Tensor([1]), 1, "a"]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_node_4():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list node.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
x1 = list(x)
|
||||
x2 = {"a": Tensor(5)}
|
||||
x3 = (0, 1.0)
|
||||
return [x1, x2, x3]
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert isinstance(res, list)
|
||||
assert len(res) == 3
|
||||
assert res[0] == [Tensor([1]), Tensor([2]), Tensor([3])]
|
||||
assert res[1] == {"a": Tensor(5)}
|
||||
assert res[2] == (0, 1.0)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -322,7 +346,6 @@ def test_return_make_list_with_nest_2():
|
|||
assert res == ([Tensor([0]), ([Tensor([0]), 1],)], (Tensor([1]), Tensor([2])))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support yet.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue