forked from mindspore-Ecosystem/mindspore
Add len with check in for
This commit is contained in:
parent
d9f25439aa
commit
fd0a6ae956
|
@ -151,6 +151,7 @@
|
|||
"mindspore/tests/st/networks/test_gpu_alexnet.py" "unused-variable"
|
||||
"mindspore/tests/st/networks/test_gpu_lenet.py" "unused-variable"
|
||||
"mindspore/tests/st/ops/custom_ops_tbe/cus_add3.py" "unused-import"
|
||||
"mindspore/tests/st/control/inner/test_002_single_for.py" "not-an-iterable"
|
||||
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py" "redefined-outer-name"
|
||||
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice.py" "redefined-builtin"
|
||||
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "mindspore/core/ir/cell.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
@ -2114,7 +2115,10 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
|
|||
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For by loop variable";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
||||
|
||||
const py::function len_with_check = python_adapter::GetPyFn(kStandardMethodModelName, kMsLenWithCheck);
|
||||
auto len_with_check_fg = ParsePythonCode(len_with_check);
|
||||
auto op_len_with_check = NewValueNode(len_with_check_fg);
|
||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
|
||||
// Get variable name of 'x' in statement 'for x in xs'
|
||||
|
@ -2128,7 +2132,7 @@ FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py:
|
|||
if (iter_node->interpret() && !IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret)) {
|
||||
iter_node = HandleInterpret(block, iter_node, iter_obj);
|
||||
}
|
||||
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
|
||||
CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len_with_check, iter_node});
|
||||
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(header_block);
|
||||
// Create loop variable 'i'
|
||||
|
|
|
@ -50,6 +50,9 @@ enum ParseStatusCode : int64_t {
|
|||
PARSE_FAILURE = 0xFF
|
||||
};
|
||||
|
||||
constexpr char kStandardMethodModelName[] = "mindspore._extends.parse.standard_method";
|
||||
constexpr char kMsLenWithCheck[] = "ms_len_with_iterable_check";
|
||||
|
||||
// Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
|
||||
|
|
|
@ -2173,6 +2173,24 @@ def ms_len(data):
|
|||
return data.__len__()
|
||||
|
||||
|
||||
@constexpr
|
||||
def python_len_with_check(data):
|
||||
"""Return the result of python built-in len function with iterable check"""
|
||||
if not hasattr(data, "__iter__"):
|
||||
raise TypeError(str(type(data)) + " object is not iterable in graph mode.")
|
||||
return len(data)
|
||||
|
||||
|
||||
def ms_len_with_iterable_check(data):
|
||||
"""Implementation of `len` with iterable check, used in len of condition."""
|
||||
if not isinstance(data, Tensor) and F.isconstant(data):
|
||||
return python_len_with_check(data)
|
||||
if not hasattr(data, "__len__"):
|
||||
type_str = str(F.typeof(data))
|
||||
const_utils.raise_type_error(type_str + " object is not iterable in graph mode.")
|
||||
return data.__len__()
|
||||
|
||||
|
||||
def floor(x):
|
||||
"""Rounds a tensor down to the closest integer element-wise."""
|
||||
return x.__floor__()
|
||||
|
|
|
@ -281,3 +281,26 @@ def test_single_for():
|
|||
input_y = Tensor([2], mstype.int32)
|
||||
res = control_flow_for(input_x, input_y)
|
||||
print("res:", res)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_for_with_not_iterable_object():
|
||||
"""
|
||||
Feature: The else branches of for loops aren't supported.
|
||||
Description: The else branches of for loops aren't supported.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_for_with_not_iterable_object():
|
||||
ret = 0
|
||||
a = 1
|
||||
for i in a:
|
||||
ret = ret + i
|
||||
return ret
|
||||
|
||||
with pytest.raises(TypeError, match="object is not iterable in graph mode"):
|
||||
control_flow_for_with_not_iterable_object()
|
||||
|
|
|
@ -156,7 +156,7 @@ def test_single_for_wrong_xs():
|
|||
|
||||
with pytest.raises(TypeError) as info:
|
||||
control_flow_for()
|
||||
assert "has no len" in str(info.value)
|
||||
assert "object is not iterable in graph mode" in str(info.value)
|
||||
|
||||
|
||||
def test_single_for_wrong_xs_2():
|
||||
|
@ -175,4 +175,4 @@ def test_single_for_wrong_xs_2():
|
|||
|
||||
with pytest.raises(TypeError) as info:
|
||||
control_flow_for()
|
||||
assert "has no len" in str(info.value)
|
||||
assert "object is not iterable in graph mode" in str(info.value)
|
||||
|
|
Loading…
Reference in New Issue