!3450 use primitive `Less` instead of `scalar_lt` for `large for loop` condition

Merge pull request !3450 from fary86/fix_large_for_loop_execute_error
This commit is contained in:
mindspore-ci-bot 2020-07-27 09:48:54 +08:00 committed by Gitee
commit 15dfaf6b97
7 changed files with 17 additions and 24 deletions

View File

@ -185,12 +185,6 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
const int MAX_FOR_LOOP_COUNT = 600;
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)

View File

@ -108,11 +108,6 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &p
auto fb = args_spec_list[2];
MS_EXCEPTION_IF_NULL(cond);
auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG);
if (unroll_flag != nullptr && GetValue<int>(unroll_flag) == 0) {
return tb->Join(fb);
}
ValuePtr v = cond->GetValueTrack();
MS_EXCEPTION_IF_NULL(v);
// for tensor as condition, keeps both true and false branch.

View File

@ -298,13 +298,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
}
// Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch'
auto prim_switch = std::make_shared<Primitive>(prim::kPrimSwitch->name());
if (!unroll_loop) {
prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0));
}
CNodePtr switch_app =
func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()),
func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
CNodePtr switch_app_new = func_graph()->NewCNode({switch_app});
func_graph()->set_output(switch_app_new);

View File

@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
MS_EXCEPTION_IF_NULL(block);
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
// create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
CNodePtr bool_node = block->func_graph()->NewCNode(
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
CNodePtr bool_node =
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
MS_EXCEPTION_IF_NULL(iter_node);
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
// Generate node for loop count and convert it to tensor, to make the loop not unroll
CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node});
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)});
CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len});
FunctionBlockPtr header_block =
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
@ -1199,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
// create loop variable 'i'
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
// create loop condition 'i < len(xs)'
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)});
CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter});
// generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));

View File

@ -48,6 +48,10 @@ enum ParseStatusCode : int {
PARSE_FAILURE = 0xFF
};
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
const int MAX_FOR_LOOP_COUNT = 600;
class AstNodeType;
class ParseAst;

View File

@ -24,7 +24,6 @@ namespace mindspore {
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
// Define python "MetaFuncGraph_" class
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
// .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_)
.def(py::init<std::string &>());
// Define python "FuncGraph" class
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")

View File

@ -72,7 +72,6 @@ class MetaFuncGraph : public FuncGraphBase {
return false;
}
}
// const bool parse_info_ = true;
protected:
template <typename Derived>