forked from OSSInnovation/mindspore
Fix large for loop execute fail
This commit is contained in:
parent
ea545dc52f
commit
d4a3d0fa14
|
@ -185,12 +185,6 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
|
|||
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
|
||||
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
||||
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
||||
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
const int MAX_FOR_LOOP_COUNT = 600;
|
||||
|
||||
class UnpackGraphPrimitive : public Primitive {
|
||||
public:
|
||||
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
|
||||
|
|
|
@ -108,11 +108,6 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
auto fb = args_spec_list[2];
|
||||
MS_EXCEPTION_IF_NULL(cond);
|
||||
|
||||
auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG);
|
||||
if (unroll_flag != nullptr && GetValue<int>(unroll_flag) == 0) {
|
||||
return tb->Join(fb);
|
||||
}
|
||||
|
||||
ValuePtr v = cond->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
// for tensor as condition, keeps both true and false branch.
|
||||
|
|
|
@ -298,13 +298,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
|
|||
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
|
||||
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
|
||||
}
|
||||
// Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch'
|
||||
auto prim_switch = std::make_shared<Primitive>(prim::kPrimSwitch->name());
|
||||
if (!unroll_loop) {
|
||||
prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0));
|
||||
}
|
||||
CNodePtr switch_app =
|
||||
func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()),
|
||||
func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
CNodePtr switch_app_new = func_graph()->NewCNode({switch_app});
|
||||
func_graph()->set_output(switch_app_new);
|
||||
|
|
|
@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
|
|||
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
|
||||
// create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
|
||||
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
||||
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
|
||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
||||
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
||||
CNodePtr bool_node = block->func_graph()->NewCNode(
|
||||
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
|
||||
CNodePtr bool_node =
|
||||
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
|
||||
|
||||
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
|
||||
|
@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
|
||||
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
||||
MS_EXCEPTION_IF_NULL(iter_node);
|
||||
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
||||
// Generate node for loop count and convert it to tensor, to make the loop not unroll
|
||||
CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node});
|
||||
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
|
||||
auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)});
|
||||
|
||||
CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len});
|
||||
|
||||
FunctionBlockPtr header_block =
|
||||
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
|
@ -1199,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
// create loop variable 'i'
|
||||
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
|
||||
// create loop condition 'i < len(xs)'
|
||||
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
|
||||
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
|
||||
auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)});
|
||||
CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter});
|
||||
|
||||
// generate the body of the for statement
|
||||
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
|
|
|
@ -48,6 +48,10 @@ enum ParseStatusCode : int {
|
|||
PARSE_FAILURE = 0xFF
|
||||
};
|
||||
|
||||
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
const int MAX_FOR_LOOP_COUNT = 600;
|
||||
|
||||
class AstNodeType;
|
||||
class ParseAst;
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@ namespace mindspore {
|
|||
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
||||
// Define python "MetaFuncGraph_" class
|
||||
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
|
||||
// .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_)
|
||||
.def(py::init<std::string &>());
|
||||
// Define python "FuncGraph" class
|
||||
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
|
||||
|
|
|
@ -72,7 +72,6 @@ class MetaFuncGraph : public FuncGraphBase {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
// const bool parse_info_ = true;
|
||||
|
||||
protected:
|
||||
template <typename Derived>
|
||||
|
|
Loading…
Reference in New Issue