From 5a4d86caa8d5034b39b8c96420dfdb334831a2a8 Mon Sep 17 00:00:00 2001 From: liangzhibo Date: Mon, 20 Jun 2022 10:18:44 +0800 Subject: [PATCH] for x in xs syntax when xs is numpy array --- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 36 +++++++++++- mindspore/ccsrc/pipeline/jit/parse/parse.h | 3 + mindspore/ccsrc/pipeline/jit/parse/resolve.h | 2 +- .../test_fallback_330_for_after_for_in_if.py | 52 +++++++++++++++++ .../test_fallback_002_single_for.py | 56 ++++++++++++++++++- .../test_fallback_310_for_after_if_in_if.py | 2 - ...test_fallback_311_for_after_if_in_while.py | 4 +- .../test_fallback_312_for_after_if_in_for.py | 11 ++-- .../test_fallback_330_for_after_for_in_if.py | 29 ---------- 9 files changed, 148 insertions(+), 47 deletions(-) create mode 100644 tests/st/fallback/control_flow/test_fallback_330_for_after_for_in_if.py diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 1c07dcfd0b8..65e4d41f956 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1950,6 +1950,38 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec return ParseForUnroll(block, node); } +AnfNodePtr Parser::ConvertInterpretIterNodeToList(const FunctionBlockPtr &block, const AnfNodePtr &iter_node, + const py::object iter_obj) { + // For interpret iter_node, convert it to list. xs --> list(xs). + py::object iter_id = python_adapter::GetPyObjAttr(iter_obj, "id"); + if (!py::isinstance(iter_id)) { + // If variable is assigned, for example: + // xs = np.array([1, 2, 3, 4]) + // for x in xs + const std::string &iter_id_str = iter_id.cast(); + return MakeInterpretNode(block, iter_node, "list(" + iter_id_str + ")"); + } + // If variable is not assigned, for example: + // for x in np.array([1, 2, 3, 4]) + const auto &interpret_iter_node = + IsPrimitiveCNode(iter_node, prim::kPrimPyInterpret) ? iter_node : HandleInterpret(block, iter_node, iter_obj); + constexpr size_t script_index = 1; + auto iter_cnode = interpret_iter_node->cast(); + MS_EXCEPTION_IF_NULL(iter_cnode); + auto iter_cnode_inputs = iter_cnode->inputs(); + auto iter_script_input = iter_cnode_inputs[script_index]; + if (!IsValueNode