!38299 Fix graph list comprehension problem and add test cases

Merge pull request !38299 from LiangZhibo/master
This commit is contained in:
i-robot 2022-07-20 09:44:08 +00:00 committed by Gitee
commit 29e5cf864f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 464 additions and 3 deletions

View File

@ -2268,9 +2268,16 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
// Create a header block.
MS_EXCEPTION_IF_NULL(block->func_graph());
FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
top_block->AddPrevBlock(block);
// Handle iter attribute.
py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
MS_EXCEPTION_IF_NULL(iter_anf_node);
bool interpret_without_internal =
IsPrimitiveCNode(iter_anf_node, prim::kPrimPyInterpret) && !iter_anf_node->interpret_internal_type();
if (iter_anf_node->interpret() || interpret_without_internal) {
iter_anf_node = ConvertInterpretIterNodeToList(block, iter_anf_node, iter_node);
}
AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
MS_EXCEPTION_IF_NULL(top_block->func_graph());
CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
@ -2278,7 +2285,6 @@ FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const
// Create header graph.
FunctionBlockPtr list_header_block =
GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
list_header_block->AddPrevBlock(top_block);
// Create hasNext apply.
AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
@ -2355,9 +2361,9 @@ AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, con
py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
// Append the element.
auto list_append_op = prim::kPrimListAppend;
MS_EXCEPTION_IF_NULL(list_body_block->func_graph());
auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
auto new_list = list_body_block->func_graph()->NewCNodeInOrder(
{NewValueNode(std::make_shared<prim::ListAppend>("ListAppend")), list_param, elt_node});
// Return new list in true branch graph.
if_true_block->func_graph()->set_output(new_list);

View File

@ -0,0 +1,165 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test graph list comprehension"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input.
Expectation: No exception.
"""
@ms_function
def foo(a):
x = [a for i in range(3)]
return x
res = foo(Tensor([1, 2, 3]))
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input.
Expectation: No exception.
"""
@ms_function
def foo(a):
x = [a + i for i in range(3)]
return x
res = foo(Tensor([1, 2, 3]))
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([2, 3, 4]))
assert np.all(res[2].asnumpy() == np.array([3, 4, 5]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input_3():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input.
Expectation: No exception.
"""
@ms_function
def foo(a):
a = a + 10
x = [a + i for i in range(3)]
return x
res = foo(Tensor([1, 2, 3]))
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([11, 12, 13]))
assert np.all(res[1].asnumpy() == np.array([12, 13, 14]))
assert np.all(res[2].asnumpy() == np.array([13, 14, 15]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input_and_condition():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input and condition.
Expectation: No exception.
"""
@ms_function
def foo(a):
x = [a for i in range(5) if i%2 == 0]
return x
res = foo(Tensor([1, 2, 3]))
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input_and_condition_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input and condition.
Expectation: No exception.
"""
@ms_function
def foo(a):
x = [a + i for i in range(5) if i%2 == 0]
return x
res = foo(Tensor([1, 2, 3]))
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([3, 4, 5]))
assert np.all(res[2].asnumpy() == np.array([5, 6, 7]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_comprehension_with_variable_input_and_condition_3():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with variable input and condition.
Expectation: RuntimeError.
"""
@ms_function
def foo(a):
x = [a + i for i in range(5) if P.ReduceSum()(a + i) > 10]
return x
with pytest.raises(RuntimeError) as raise_info:
foo(Tensor([1, 2, 3]))
assert "Cannot join the return values of different branches" in str(raise_info.value)

View File

@ -0,0 +1,290 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test graph list comprehension"""
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
def test_list_comprehension_with_local_inputs():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [i for i in range(3)]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([0, 1, 2]))
def test_list_comprehension_with_local_inputs_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [i + 1 for i in range(3)]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([1, 2, 3]))
def test_list_comprehension_with_local_inputs_and_condition():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [i + 1 for i in range(5) if i%2 == 0]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([1, 3, 5]))
def test_list_comprehension_with_pre_block_local_input():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input from previous block.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
x = [a for i in range(3)]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([10, 10, 10]))
def test_list_comprehension_with_pre_block_local_input_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input from previous block.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
x = [a + i for i in range(3)]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
def test_list_comprehension_with_pre_block_local_input_and_condition():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input from previous block.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
x = [a + i for i in range(3) if a > 5]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
def test_list_comprehension_with_pre_block_local_input_and_condition_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with local input from previous block.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
x = [a + i for i in range(5) if a + i < 13]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([10, 11, 12]))
def test_list_comprehension_with_numpy_input():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with numpy input.
Expectation: No exception.
"""
@ms_function
def foo():
a = np.array([1, 2, 3])
x = [a for i in range(3)]
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
res = foo()
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
def test_list_comprehension_with_numpy_input_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with numpy input.
Expectation: No exception.
"""
@ms_function
def foo():
a = np.array([1, 2, 3])
x = [a + i for i in range(3)]
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
res = foo()
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([2, 3, 4]))
assert np.all(res[2].asnumpy() == np.array([3, 4, 5]))
def test_list_comprehension_with_numpy_input_and_condition():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with numpy input and condition.
Expectation: No exception.
"""
@ms_function
def foo():
a = np.array([1, 2, 3])
x = [a for i in range(5) if i%2 == 0]
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
res = foo()
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[2].asnumpy() == np.array([1, 2, 3]))
def test_list_comprehension_with_numpy_input_and_condition_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with numpy input and condition.
Expectation: No exception.
"""
@ms_function
def foo():
a = np.array([1, 2, 3])
x = [a + i for i in range(5) if np.sum(a + i) > 10]
return Tensor(x[0]), Tensor(x[1]), Tensor(x[2])
res = foo()
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([3, 4, 5]))
assert np.all(res[1].asnumpy() == np.array([4, 5, 6]))
assert np.all(res[2].asnumpy() == np.array([5, 6, 7]))
def test_list_comprehension_with_numpy_input_and_condition_3():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with numpy input and condition.
Expectation: No exception.
"""
@ms_function
def foo():
a = np.array([1, 2, 3])
x = [a + i for i in range(5) if np.sum(a + i) > 20]
return x
res = foo()
assert not res
def test_list_comprehension_with_iter_list():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with list as iteration object.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
m = [1, 2, 3, 4, 5]
x = [a + i for i in m if (a + i)%2 == 0]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([12, 14]))
def test_list_comprehension_with_iter_list_2():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with list as iteration object.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
m = np.array([1, 2, 3, 4, 5])
x = [a + i for i in m if (a + i)%2 == 0]
return Tensor(x)
res = foo()
assert np.all(res.asnumpy() == np.array([12, 14]))
def test_list_comprehension_with_iter_list_3():
"""
Feature: Graph isinstance.
Description: Graph list comprehension syntax with list as iteration object.
Expectation: No exception.
"""
@ms_function
def foo():
a = 10
m = [Tensor([1]), Tensor([2]), Tensor([3])]
x = [a + i for i in m]
return x[0], x[1], x[2]
res = foo()
assert len(res) == 3
assert res[0] == 11
assert res[1] == 12
assert res[2] == 13