!36239 Enable syntax of "for x in xs", when xs is an interpret node

Merge pull request !36239 from LiangZhibo/for
This commit is contained in:
i-robot 2022-06-22 01:08:52 +00:00 committed by Gitee
commit c39a0b5969
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 148 additions and 47 deletions

View File

@ -1945,6 +1945,38 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
return ParseForUnroll(block, node);
}
AnfNodePtr Parser::ConvertInterpretIterNodeToList(const FunctionBlockPtr &block, const AnfNodePtr &iter_node,
const py::object iter_obj) {
// For interpret iter_node, convert it to list. xs --> list(xs).
py::object iter_id = python_adapter::GetPyObjAttr(iter_obj, "id");
if (!py::isinstance<py::none>(iter_id)) {
// If variable is assigned, for example:
// xs = np.array([1, 2, 3, 4])
// for x in xs
const std::string &iter_id_str = iter_id.cast<py::str>();
return MakeInterpretNode(block, iter_node, "list(" + iter_id_str + ")");
}
// If variable is not assigned, for example:
// for x in np.array([1, 2, 3, 4])
const auto &interpret_iter_node =
IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret) ? iter_node : HandleInterpret(block, iter_node, iter_obj);
constexpr size_t script_index = 1;
auto iter_cnode = interpret_iter_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(iter_cnode);
auto iter_cnode_inputs = iter_cnode->inputs();
auto iter_script_input = iter_cnode_inputs[script_index];
if (!IsValueNode<Script>(iter_script_input)) {
MS_LOG(EXCEPTION) << "The second input to iter node: " << interpret_iter_node->DebugString()
<< " should be a script value node but got: " << iter_script_input->DebugString() << ".";
}
auto script = iter_script_input->cast<ValueNodePtr>();
auto script_val = script->value()->cast<ScriptPtr>();
auto script_text = script_val->script();
auto new_script_val = NewValueNode(std::make_shared<Script>("list(" + script_text + ")"));
iter_cnode_inputs[script_index] = new_script_val;
return block->func_graph()->NewCNodeInOrder(iter_cnode_inputs);
}
// Implement unroll for statement with tuple/getitem.
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
@ -1962,9 +1994,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
bool interpret_without_internal =
IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret) && !iter_node->interpret_internal_type();
if (iter_node->interpret() || interpret_without_internal) {
MS_EXCEPTION(TypeError) << "Parsing syntax 'for x in xs', xs can not be interpret node but got "
<< iter_node->DebugString()
<< ".\nNodeInfo: " << trace::GetDebugInfo(iter_node->debug_info());
iter_node = ConvertInterpretIterNodeToList(block, iter_node, iter_obj);
}
// Generate node for loop count and convert it to tensor, to make the loop not unroll
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});

View File

@ -237,6 +237,9 @@ class Parser {
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
// Make interpret node.
AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
// Convert interpret iter node to list.
AnfNodePtr ConvertInterpretIterNodeToList(const FunctionBlockPtr &block, const AnfNodePtr &iter_node,
const py::object iter_obj);
// Check if the node need interpreting.
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object);

View File

@ -93,7 +93,7 @@ class Script final : public Named {
~Script() override = default;
MS_DECLARE_PARENT(Script, Named);
std::string script() { return script_; }
std::string script() const { return script_; }
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
}

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_if_3():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def func3303():
x = np.array([1, 2, 3])
y = np.array([5, 6, 7])
k = []
if x[2] < y[0]:
y = y - x
for i in y:
k.append(i)
z = Tensor(k)
out = 1
for i in z:
out = out * i
return out
res = func3303()
assert res == 64

View File

@ -121,6 +121,58 @@ def test_single_for_x_in_xs():
for i in x:
y += i
return Tensor(y)
with pytest.raises(TypeError) as err:
res = control_flow_for()
assert np.allclose(res.asnumpy(), 3.3)
def test_single_for_x_in_xs_2():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_for():
y = np.array(0)
for i in np.array([1.1, 2.2]):
y += i
return Tensor(y)
res = control_flow_for()
assert np.allclose(res.asnumpy(), 3.3)
def test_single_for_wrong_xs():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_for():
y = np.array(0)
for i in np.int64(1):
y += i
return Tensor(y)
with pytest.raises(TypeError) as info:
control_flow_for()
assert "Parsing syntax 'for x in xs', xs can not be interpret node " in str(err.value)
assert "object is not iterable" in str(info.value)
def test_single_for_wrong_xs_2():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_for():
x = np.int64(1)
y = np.array(0)
for i in x:
y += i
return Tensor(y)
with pytest.raises(TypeError) as info:
control_flow_for()
assert "object is not iterable" in str(info.value)

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow for after if in if scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@ -94,7 +93,6 @@ def test_for_after_if_in_if_numpy():
assert np.all(res.asnumpy() == np.array([23, 24, 25]))
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
def test_for_after_if_in_if_numpy_2():
"""
Feature: JIT Fallback

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow for after if in while scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@ -42,7 +41,6 @@ def test_for_after_if_in_while_numpy():
assert res == -4
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
def test_for_after_if_in_while_numpy_2():
"""
Feature: JIT Fallback
@ -62,4 +60,4 @@ def test_for_after_if_in_while_numpy_2():
y += i - 20
return sum(y)
res = control_flow_for_after_if_in_while()
assert res == 18
assert res == -222

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow for after if in for scenario"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@ -66,7 +65,6 @@ def test_for_after_if_in_for_tensor_2():
assert res == 1
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
def test_for_after_if_in_for_numpy():
"""
Feature: JIT Fallback
@ -82,14 +80,13 @@ def test_for_after_if_in_for_numpy():
y += 1
if sum(x) > 15:
break
for _ in y:
x += y
for i in y:
x += i
return Tensor(max(x))
res = control_flow_for_after_if_in_for()
assert res == 11
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
def test_for_after_if_in_for_numpy_2():
"""
Feature: JIT Fallback
@ -110,6 +107,6 @@ def test_for_after_if_in_for_numpy_2():
y += 2
for i in range(3):
a += y[i]
return Tensor(max(x))
return Tensor(a)
res = control_flow_for_after_if_in_for()
assert res == 17
assert res == 26

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
""" test graph fallback control flow."""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
@ -70,34 +69,6 @@ def test_for_after_for_in_if_2():
assert res_y == 4
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_for_after_for_in_if_3():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def func3303():
x = np.array([1, 2, 3])
y = np.array([5, 6, 7])
k = []
if x[2] < y[0]:
y = y - x
for i in y:
k.append(i)
z = Tensor(k)
out = 0
for i in z:
out = out * i
return out
res = func3303()
assert res == 64
def test_for_after_for_in_if_4():
"""
Feature: JIT Fallback