forked from mindspore-Ecosystem/mindspore
!38299 Fix graph list comprehension problem and add test cases
Merge pull request !38299 from LiangZhibo/master
This commit is contained in:
commit
29e5cf864f
|
@ -2268,9 +2268,16 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
|
|||
// Create a header block.
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
|
||||
top_block->AddPrevBlock(block);
|
||||
// Handle iter attribute.
|
||||
py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
|
||||
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
|
||||
MS_EXCEPTION_IF_NULL(iter_anf_node);
|
||||
bool interpret_without_internal =
|
||||
IsPrimitiveCNode(iter_anf_node, prim::kPrimPyInterpret) && !iter_anf_node->interpret_internal_type();
|
||||
if (iter_anf_node->interpret() || interpret_without_internal) {
|
||||
iter_anf_node = ConvertInterpretIterNodeToList(block, iter_anf_node, iter_node);
|
||||
}
|
||||
AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
||||
MS_EXCEPTION_IF_NULL(top_block->func_graph());
|
||||
CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
|
||||
|
@ -2278,7 +2285,6 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
|
|||
// Create header graph.
|
||||
FunctionBlockPtr list_header_block =
|
||||
GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
list_header_block->AddPrevBlock(top_block);
|
||||
|
||||
// Create hasNext apply.
|
||||
AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
|
||||
|
@ -2355,9 +2361,9 @@ AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, con
|
|||
py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
|
||||
AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
|
||||
// Append the element.
|
||||
auto list_append_op = prim::kPrimListAppend;
|
||||
MS_EXCEPTION_IF_NULL(list_body_block->func_graph());
|
||||
auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
|
||||
auto new_list = list_body_block->func_graph()->NewCNodeInOrder(
|
||||
{NewValueNode(std::make_shared<prim::ListAppend>("ListAppend")), list_param, elt_node});
|
||||
// Return new list in true branch graph.
|
||||
if_true_block->func_graph()->set_output(new_list);
|
||||
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test graph list comprehension"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [a for i in range(3)]
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [a + i for i in range(3)]
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([2, 3, 4]))
|
||||
assert np.all(res[2].asnumpy() == np.array([3, 4, 5]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input_3():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
a = a + 10
|
||||
x = [a + i for i in range(3)]
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([11, 12, 13]))
|
||||
assert np.all(res[1].asnumpy() == np.array([12, 13, 14]))
|
||||
assert np.all(res[2].asnumpy() == np.array([13, 14, 15]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input_and_condition():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input and condition.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [a for i in range(5) if i%2 == 0]
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input_and_condition_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input and condition.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [a + i for i in range(5) if i%2 == 0]
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([3, 4, 5]))
|
||||
assert np.all(res[2].asnumpy() == np.array([5, 6, 7]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_comprehension_with_variable_input_and_condition_3():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with variable input and condition.
|
||||
Expectation: RuntimeError.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(a):
|
||||
x = [a + i for i in range(5) if P.ReduceSum()(a + i) > 10]
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info:
|
||||
foo(Tensor([1, 2, 3]))
|
||||
assert "Cannot join the return values of different branches" in str(raise_info.value)
|
|
@ -0,0 +1,290 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test graph list comprehension"""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_list_comprehension_with_local_inputs():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [i for i in range(3)]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([0, 1, 2]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_local_inputs_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [i + 1 for i in range(3)]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_local_inputs_and_condition():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [i + 1 for i in range(5) if i%2 == 0]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([1, 3, 5]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_pre_block_local_input():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input from previous block.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
x = [a for i in range(3)]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([10, 10, 10]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_pre_block_local_input_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input from previous block.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
x = [a + i for i in range(3)]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_pre_block_local_input_and_condition():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input from previous block.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
x = [a + i for i in range(3) if a > 5]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_pre_block_local_input_and_condition_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with local input from previous block.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
x = [a + i for i in range(5) if a + i < 13]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_numpy_input():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with numpy input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = np.array([1, 2, 3])
|
||||
x = [a for i in range(3)]
|
||||
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_numpy_input_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with numpy input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = np.array([1, 2, 3])
|
||||
x = [a + i for i in range(3)]
|
||||
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([2, 3, 4]))
|
||||
assert np.all(res[2].asnumpy() == np.array([3, 4, 5]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_numpy_input_and_condition():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with numpy input and condition.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = np.array([1, 2, 3])
|
||||
x = [a for i in range(5) if i%2 == 0]
|
||||
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_numpy_input_and_condition_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with numpy input and condition.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = np.array([1, 2, 3])
|
||||
x = [a + i for i in range(5) if np.sum(a + i) > 10]
|
||||
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([3, 4, 5]))
|
||||
assert np.all(res[1].asnumpy() == np.array([4, 5, 6]))
|
||||
assert np.all(res[2].asnumpy() == np.array([5, 6, 7]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_numpy_input_and_condition_3():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with numpy input and condition.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = np.array([1, 2, 3])
|
||||
x = [a + i for i in range(5) if np.sum(a + i) > 20]
|
||||
return x
|
||||
|
||||
res = foo()
|
||||
assert not res
|
||||
|
||||
|
||||
def test_list_comprehension_with_iter_list():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with list as iteration object.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
m = [1, 2, 3, 4, 5]
|
||||
x = [a + i for i in m if (a + i)%2 == 0]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([12, 14]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_iter_list_2():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with list as iteration object.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
m = np.array([1, 2, 3, 4, 5])
|
||||
x = [a + i for i in m if (a + i)%2 == 0]
|
||||
return Tensor(x)
|
||||
|
||||
res = foo()
|
||||
assert np.all(res.asnumpy() == np.array([12, 14]))
|
||||
|
||||
|
||||
def test_list_comprehension_with_iter_list_3():
|
||||
"""
|
||||
Feature: Graph isinstance.
|
||||
Description: Graph list comprehension syntax with list as iteration object.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 10
|
||||
m = [Tensor([1]), Tensor([2]), Tensor([3])]
|
||||
x = [a + i for i in m]
|
||||
return x[0], x[1], x[2]
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert res[0] == 11
|
||||
assert res[1] == 12
|
||||
assert res[2] == 13
|
Loading…
Reference in New Issue