Add len with check in for

This commit is contained in:
liangzhibo 2022-09-26 10:27:41 +08:00
parent d9f25439aa
commit fd0a6ae956
6 changed files with 53 additions and 4 deletions

View File

@ -151,6 +151,7 @@
"mindspore/tests/st/networks/test_gpu_alexnet.py" "unused-variable"
"mindspore/tests/st/networks/test_gpu_lenet.py" "unused-variable"
"mindspore/tests/st/ops/custom_ops_tbe/cus_add3.py" "unused-import"
"mindspore/tests/st/control/inner/test_002_single_for.py" "not-an-iterable"
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py" "redefined-outer-name"
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py" "redefined-builtin"
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"

View File

@ -35,6 +35,7 @@
#include "pipeline/jit/debug/trace.h"
#include "mindspore/core/ir/cell.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/python_adapter.h"
namespace mindspore {
namespace parse {
@ -2114,7 +2115,10 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For by loop variable";
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
const py::function len_with_check = python_adapter::GetPyFn(kStandardMethodModelName, kMsLenWithCheck);
auto len_with_check_fg = ParsePythonCode(len_with_check);
auto op_len_with_check = NewValueNode(len_with_check_fg);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
// Get variable name of 'x' in statement 'for x in xs'
@ -2128,7 +2132,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
if (iter_node->interpret() && !IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret)) {
iter_node = HandleInterpret(block, iter_node, iter_obj);
}
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len_with_check, iter_node});
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
// Create loop variable 'i'

View File

@ -50,6 +50,9 @@ enum ParseStatusCode : int64_t {
PARSE_FAILURE = 0xFF
};
constexpr char kStandardMethodModelName[] = "mindspore._extends.parse.standard_method";
constexpr char kMsLenWithCheck[] = "ms_len_with_iterable_check";
// Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were

View File

@ -2173,6 +2173,24 @@ def ms_len(data):
return data.__len__()
@constexpr
def python_len_with_check(data):
"""Return the result of python built-in len function with iterable check"""
if not hasattr(data, "__iter__"):
raise TypeError(str(type(data)) + " object is not iterable in graph mode.")
return len(data)
def ms_len_with_iterable_check(data):
"""Implementation of `len` with iterable check, used in len of condition."""
if not isinstance(data, Tensor) and F.isconstant(data):
return python_len_with_check(data)
if not hasattr(data, "__len__"):
type_str = str(F.typeof(data))
const_utils.raise_type_error(type_str + " object is not iterable in graph mode.")
return data.__len__()
def floor(x):
"""Rounds a tensor down to the closest integer element-wise."""
return x.__floor__()

View File

@ -281,3 +281,26 @@ def test_single_for():
input_y = Tensor([2], mstype.int32)
res = control_flow_for(input_x, input_y)
print("res:", res)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_with_not_iterable_object():
"""
Feature: The else branches of for loops aren't supported.
Description: The else branches of for loops aren't supported.
Expectation: No exception.
"""
@ms_function
def control_flow_for_with_not_iterable_object():
ret = 0
a = 1
for i in a:
ret = ret + i
return ret
with pytest.raises(TypeError, match="object is not iterable in graph mode"):
control_flow_for_with_not_iterable_object()

View File

@ -156,7 +156,7 @@ def test_single_for_wrong_xs():
with pytest.raises(TypeError) as info:
control_flow_for()
assert "has no len" in str(info.value)
assert "object is not iterable in graph mode" in str(info.value)
def test_single_for_wrong_xs_2():
@ -175,4 +175,4 @@ def test_single_for_wrong_xs_2():
with pytest.raises(TypeError) as info:
control_flow_for()
assert "has no len" in str(info.value)
assert "object is not iterable in graph mode" in str(info.value)