forked from mindspore-Ecosystem/mindspore
!36239 Enable syntax of "for x in xs", when xs is an interpret node
Merge pull request !36239 from LiangZhibo/for
This commit is contained in:
commit
c39a0b5969
|
@ -1945,6 +1945,38 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
|
|||
return ParseForUnroll(block, node);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ConvertInterpretIterNodeToList(const FunctionBlockPtr &block, const AnfNodePtr &iter_node,
|
||||
const py::object iter_obj) {
|
||||
// For interpret iter_node, convert it to list. xs --> list(xs).
|
||||
py::object iter_id = python_adapter::GetPyObjAttr(iter_obj, "id");
|
||||
if (!py::isinstance<py::none>(iter_id)) {
|
||||
// If variable is assigned, for example:
|
||||
// xs = np.array([1, 2, 3, 4])
|
||||
// for x in xs
|
||||
const std::string &iter_id_str = iter_id.cast<py::str>();
|
||||
return MakeInterpretNode(block, iter_node, "list(" + iter_id_str + ")");
|
||||
}
|
||||
// If variable is not assigned, for example:
|
||||
// for x in np.array([1, 2, 3, 4])
|
||||
const auto &interpret_iter_node =
|
||||
IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret) ? iter_node : HandleInterpret(block, iter_node, iter_obj);
|
||||
constexpr size_t script_index = 1;
|
||||
auto iter_cnode = interpret_iter_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(iter_cnode);
|
||||
auto iter_cnode_inputs = iter_cnode->inputs();
|
||||
auto iter_script_input = iter_cnode_inputs[script_index];
|
||||
if (!IsValueNode<Script>(iter_script_input)) {
|
||||
MS_LOG(EXCEPTION) << "The second input to iter node: " << interpret_iter_node->DebugString()
|
||||
<< " should be a script value node but got: " << iter_script_input->DebugString() << ".";
|
||||
}
|
||||
auto script = iter_script_input->cast<ValueNodePtr>();
|
||||
auto script_val = script->value()->cast<ScriptPtr>();
|
||||
auto script_text = script_val->script();
|
||||
auto new_script_val = NewValueNode(std::make_shared<Script>("list(" + script_text + ")"));
|
||||
iter_cnode_inputs[script_index] = new_script_val;
|
||||
return block->func_graph()->NewCNodeInOrder(iter_cnode_inputs);
|
||||
}
|
||||
|
||||
// Implement unroll for statement with tuple/getitem.
|
||||
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
||||
|
@ -1962,9 +1994,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
bool interpret_without_internal =
|
||||
IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret) && !iter_node->interpret_internal_type();
|
||||
if (iter_node->interpret() || interpret_without_internal) {
|
||||
MS_EXCEPTION(TypeError) << "Parsing syntax 'for x in xs', xs can not be interpret node but got "
|
||||
<< iter_node->DebugString()
|
||||
<< ".\nNodeInfo: " << trace::GetDebugInfo(iter_node->debug_info());
|
||||
iter_node = ConvertInterpretIterNodeToList(block, iter_node, iter_obj);
|
||||
}
|
||||
// Generate node for loop count and convert it to tensor, to make the loop not unroll
|
||||
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
|
||||
|
|
|
@ -237,6 +237,9 @@ class Parser {
|
|||
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
|
||||
// Make interpret node.
|
||||
AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
|
||||
// Convert interpret iter node to list.
|
||||
AnfNodePtr ConvertInterpretIterNodeToList(const FunctionBlockPtr &block, const AnfNodePtr &iter_node,
|
||||
const py::object iter_obj);
|
||||
// Check if the node need interpreting.
|
||||
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
|
||||
const py::object &value_object);
|
||||
|
|
|
@ -93,7 +93,7 @@ class Script final : public Named {
|
|||
~Script() override = default;
|
||||
MS_DECLARE_PARENT(Script, Named);
|
||||
|
||||
std::string script() { return script_; }
|
||||
std::string script() const { return script_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScript>(shared_from_base<Script>());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_after_for_in_if_3():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def func3303():
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([5, 6, 7])
|
||||
k = []
|
||||
if x[2] < y[0]:
|
||||
y = y - x
|
||||
for i in y:
|
||||
k.append(i)
|
||||
|
||||
z = Tensor(k)
|
||||
out = 1
|
||||
for i in z:
|
||||
out = out * i
|
||||
return out
|
||||
|
||||
res = func3303()
|
||||
assert res == 64
|
|
@ -121,6 +121,58 @@ def test_single_for_x_in_xs():
|
|||
for i in x:
|
||||
y += i
|
||||
return Tensor(y)
|
||||
with pytest.raises(TypeError) as err:
|
||||
res = control_flow_for()
|
||||
assert np.allclose(res.asnumpy(), 3.3)
|
||||
|
||||
|
||||
def test_single_for_x_in_xs_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_for():
|
||||
y = np.array(0)
|
||||
for i in np.array([1.1, 2.2]):
|
||||
y += i
|
||||
return Tensor(y)
|
||||
res = control_flow_for()
|
||||
assert np.allclose(res.asnumpy(), 3.3)
|
||||
|
||||
|
||||
def test_single_for_wrong_xs():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_for():
|
||||
y = np.array(0)
|
||||
for i in np.int64(1):
|
||||
y += i
|
||||
return Tensor(y)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
control_flow_for()
|
||||
assert "Parsing syntax 'for x in xs', xs can not be interpret node " in str(err.value)
|
||||
assert "object is not iterable" in str(info.value)
|
||||
|
||||
|
||||
def test_single_for_wrong_xs_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_for():
|
||||
x = np.int64(1)
|
||||
y = np.array(0)
|
||||
for i in x:
|
||||
y += i
|
||||
return Tensor(y)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
control_flow_for()
|
||||
assert "object is not iterable" in str(info.value)
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow for after if in if scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -94,7 +93,6 @@ def test_for_after_if_in_if_numpy():
|
|||
assert np.all(res.asnumpy() == np.array([23, 24, 25]))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
|
||||
def test_for_after_if_in_if_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow for after if in while scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -42,7 +41,6 @@ def test_for_after_if_in_while_numpy():
|
|||
assert res == -4
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
|
||||
def test_for_after_if_in_while_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -62,4 +60,4 @@ def test_for_after_if_in_while_numpy_2():
|
|||
y += i - 20
|
||||
return sum(y)
|
||||
res = control_flow_for_after_if_in_while()
|
||||
assert res == 18
|
||||
assert res == -222
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow for after if in for scenario"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -66,7 +65,6 @@ def test_for_after_if_in_for_tensor_2():
|
|||
assert res == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
|
||||
def test_for_after_if_in_for_numpy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -82,14 +80,13 @@ def test_for_after_if_in_for_numpy():
|
|||
y += 1
|
||||
if sum(x) > 15:
|
||||
break
|
||||
for _ in y:
|
||||
x += y
|
||||
for i in y:
|
||||
x += i
|
||||
return Tensor(max(x))
|
||||
res = control_flow_for_after_if_in_for()
|
||||
assert res == 11
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support to get attribute for InterpretObject.')
|
||||
def test_for_after_if_in_for_numpy_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -110,6 +107,6 @@ def test_for_after_if_in_for_numpy_2():
|
|||
y += 2
|
||||
for i in range(3):
|
||||
a += y[i]
|
||||
return Tensor(max(x))
|
||||
return Tensor(a)
|
||||
res = control_flow_for_after_if_in_for()
|
||||
assert res == 17
|
||||
assert res == 26
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
@ -70,34 +69,6 @@ def test_for_after_for_in_if_2():
|
|||
assert res_y == 4
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_for_after_for_in_if_3():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def func3303():
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([5, 6, 7])
|
||||
k = []
|
||||
if x[2] < y[0]:
|
||||
y = y - x
|
||||
for i in y:
|
||||
k.append(i)
|
||||
|
||||
z = Tensor(k)
|
||||
out = 0
|
||||
for i in z:
|
||||
out = out * i
|
||||
return out
|
||||
|
||||
res = func3303()
|
||||
assert res == 64
|
||||
|
||||
|
||||
def test_for_after_for_in_if_4():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
Loading…
Reference in New Issue