Support ListComp and GeneratorExp in Graph Mode.

This commit is contained in:
Zhang Qinghua 2021-08-04 16:15:11 +08:00
parent dfd3f92858
commit aef396f9f2
8 changed files with 380 additions and 142 deletions

View File

@ -20,6 +20,7 @@
#include <string>
#include <memory>
#include <algorithm>
#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
@ -329,10 +330,10 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
// A block should be marked matured if its predecessor blocks have been processed
void FunctionBlock::Mature() {
const auto &graphParamVec = func_graph_->parameters();
for (auto &paramItr : graphParamVec) {
MS_EXCEPTION_IF_NULL(paramItr);
auto param = paramItr->cast<ParameterPtr>();
const auto &graph_params = func_graph_->parameters();
for (auto &param_itr : graph_params) {
MS_EXCEPTION_IF_NULL(param_itr);
auto param = param_itr->cast<ParameterPtr>();
if (phi_nodes_.find(param) != phi_nodes_.cend()) {
SetPhiArgument(param);
}
@ -356,7 +357,7 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
}
// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr &node) {
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
MS_EXCEPTION_IF_NULL(target_block);
if (func_graph_->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
@ -364,9 +365,7 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr
}
std::vector<AnfNodePtr> input_nodes;
input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
if (node != nullptr) {
input_nodes.emplace_back(node);
}
(void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
jumps_[target_block.get()] = jump;

View File

@ -57,7 +57,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, const AnfNodePtr &node);
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
bool unroll_loop = true);

View File

@ -130,6 +130,8 @@ void Parser::BuildMethodMap() {
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
expr_method_map_["Dict"] = &Parser::ParseDict;
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
expr_method_map_["ListComp"] = &Parser::ParseListComp;
expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
}
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
@ -156,8 +158,8 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
}
py::object node = ast->GetAstNode();
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
constexpr auto kMinListSize = 2;
if (ret.size() < kMinListSize) {
constexpr auto min_list_size = 2;
if (ret.size() < min_list_size) {
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
}
py::str desc =
@ -169,18 +171,15 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
FuncGraphPtr Parser::ParseFuncGraph() {
// Get ast FunctionDef node
py::object node = ast_->GetAstNode();
FunctionBlockPtr pFnBlock = ParseFunction(node);
FunctionBlockPtr fn_block = ParseFunction(node);
if (errcode() != PARSE_SUCCESS) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
return nullptr;
}
RemoveUnnecessaryPhis();
MS_EXCEPTION_IF_NULL(pFnBlock);
CheckFuncReturn(pFnBlock->func_graph(), ast_);
return pFnBlock->func_graph();
MS_EXCEPTION_IF_NULL(fn_block);
CheckFuncReturn(fn_block->func_graph(), ast_);
return fn_block->func_graph();
}
void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
@ -261,14 +260,14 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
// The node created in the parsefunction context, will inherit the scope created using scope_guard
ScopeGuard scope_guard(scope);
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
if (block != nullptr) {
pFunBlock->AddPrevBlock(block);
func_block->AddPrevBlock(block);
} else {
func_graph_ = pFunBlock->func_graph();
func_graph_ = func_block->func_graph();
}
pFunBlock->Mature();
auto current_fg = pFunBlock->func_graph();
func_block->Mature();
auto current_fg = func_block->func_graph();
auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
MS_LOG(DEBUG) << "The function name is " << function_name;
current_fg->debug_info()->set_name(function_name);
@ -286,27 +285,27 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
MS_LOG(ERROR) << "Set flags failed";
return nullptr;
}
GenerateArgsNodeForFunction(pFunBlock, node);
GenerateArgsNodeForFunction(func_block, node);
// When parsing the top graph of construct, save the top graph
if (GetTopFuncGraph() == nullptr) {
UpdateTopFuncGraph(pFunBlock->func_graph());
UpdateTopFuncGraph(func_block->func_graph());
}
// Save the function node to block
pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
func_block->WriteVariable(function_name, NewValueNode(current_fg));
py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
(void)ParseStatements(pFunBlock, funcObj);
(void)ParseStatements(func_block, funcObj);
// Add unused variables as isolate nodes.
for (auto &func_block : func_block_list_) {
MS_EXCEPTION_IF_NULL(func_block);
if (func_block->func_graph()->get_return() != nullptr) {
for (auto &func_block_item : func_block_list_) {
MS_EXCEPTION_IF_NULL(func_block_item);
if (func_block_item->func_graph()->get_return() != nullptr) {
// Find unused variables.
func_block->FindIsolatedNodes();
func_block_item->FindIsolatedNodes();
// Attach all isolated nodes.
func_block->AttachIsolatedNodesBeforeReturn();
func_block_item->AttachIsolatedNodesBeforeReturn();
}
}
@ -315,8 +314,8 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
}
GenerateArgsDefaultValueForFunction(pFunBlock, node);
return pFunBlock;
GenerateArgsDefaultValueForFunction(func_block, node);
return func_block;
}
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
@ -461,14 +460,14 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
MS_LOG(DEBUG) << "Process ast return";
MS_EXCEPTION_IF_NULL(block);
// Create return valuenode
AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
AnfNodePtr return_value_node = NewValueNode(prim::kPrimReturn);
// Parse the return Statements value
py::object value = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
AnfNodePtr return_expr_node = ParseExprNode(block, value);
// Create the cnode
auto block_fg = block->func_graph();
CNodePtr pReturnCNode = block_fg->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
block_fg->set_return(pReturnCNode);
CNodePtr return_node = block_fg->NewCNodeInOrder({return_value_node, return_expr_node});
block_fg->set_return(return_node);
return block;
}
@ -583,6 +582,7 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
}
AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
@ -1117,18 +1117,18 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
// If the return_ is set ,it has its own continuation block
// If the return_ is set, it has its own continuation block
if (true_end->func_graph()->get_return() == nullptr) {
true_end->Jump(after_block, nullptr);
true_end->Jump(after_block, {});
}
// Process the orelse branch
py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
// If the return_ is set ,it has its own continuation block
// If the return_ is set, it has its own continuation block
if (false_end->func_graph()->get_return() == nullptr) {
false_end->Jump(after_block, nullptr);
false_end->Jump(after_block, {});
}
block->ConditionalJump(bool_node, true_block, false_block);
@ -1158,7 +1158,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
body_block->AddPrevBlock(header_block);
after_block->AddPrevBlock(header_block);
block->Jump(header_block, nullptr);
block->Jump(header_block, {});
py::object test_node = python_adapter::GetPyObjAttr(node, "test");
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
@ -1171,7 +1171,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
if (after_body->func_graph()->get_return() == nullptr) {
after_body->Jump(header_block, nullptr);
after_body->Jump(header_block, {});
}
header_block->Mature();
@ -1179,7 +1179,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, nullptr);
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;
}
@ -1200,16 +1200,17 @@ CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const Functio
return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
}
FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
TraceGuard trace_guard(trace_info);
FunctionBlockPtr body_block = MakeFunctionBlock(*this);
return body_block;
FunctionBlockPtr block = MakeFunctionBlock(*this);
MS_EXCEPTION_IF_NULL(block);
return block;
}
int64_t Parser::GetForTransToWhileLoop() {
// int64 support 63bits positive num mostly.
constexpr auto kMaxNumLength = 10;
if (max_for_loop_count_str_.size() > kMaxNumLength || max_for_loop_count_str_.empty()) {
constexpr auto max_num_length = 10;
if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
return MAX_FOR_LOOP_COUNT;
}
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
@ -1222,6 +1223,7 @@ int64_t Parser::GetForTransToWhileLoop() {
ss >> loop_count;
return loop_count;
}
// A for loop will generate 3 functions :the test, the body, and the continuation
// for x in xs:
// body
@ -1260,10 +1262,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
}
FunctionBlockPtr true_end = ParseForIter(true_block, node);
true_end->Jump(after_block, nullptr);
true_end->Jump(after_block, {});
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
false_end->Jump(after_block, nullptr);
false_end->Jump(after_block, {});
block->ConditionalJump(bool_node, true_block, false_block);
after_block->Mature();
@ -1288,14 +1290,13 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
// Generate the iterator apply
CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
MS_EXCEPTION_IF_NULL(iter_apply);
FunctionBlockPtr header_block =
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
// Generate the hasnext apply which is a condition
ParameterPtr iter_param = header_block->func_graph()->add_parameter();
CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
// Generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
// Generate the iterator next apply
@ -1323,7 +1324,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
MS_EXCEPTION_IF_NULL(after_block);
after_block->AddPrevBlock(header_block);
block->Jump(header_block, iter_apply);
block->Jump(header_block, {iter_apply});
body_block->Mature();
header_block->ConditionalJump(cond_apply, body_block, after_block);
@ -1332,7 +1333,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, iter2_app);
after_body_block->Jump(header_block, {iter2_app});
}
header_block->Mature();
@ -1340,7 +1341,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, nullptr);
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;
}
@ -1377,8 +1378,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
FunctionBlockPtr header_block =
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(header_block);
// Create loop variable 'i'
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
@ -1388,7 +1388,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
// Generate the body of the for statement
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
MS_EXCEPTION_IF_NULL(body_block);
body_block->AddPrevBlock(header_block);
// Create 'x = xs[i]'
@ -1419,7 +1419,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
CNodePtr zero_tensor =
block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
block->Jump(header_block, zero_tensor);
block->Jump(header_block, {zero_tensor});
body_block->Mature();
header_block->ConditionalJump(cond_node, body_block, after_block, false);
@ -1429,7 +1429,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
if (after_body_block->func_graph()->get_return() == nullptr) {
after_body_block->Jump(header_block, loop_var_inc);
after_body_block->Jump(header_block, {loop_var_inc});
}
header_block->Mature();
@ -1437,7 +1437,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
auto &end_block = loop_context.EndBlock();
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, nullptr);
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;
}
@ -1489,6 +1489,155 @@ AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &n
return switch_app_call;
}
FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
const py::object &generator_node) {
// Create a header block.
FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
// Handle iter attribute.
py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
// Create header graph.
FunctionBlockPtr list_header_block =
GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
list_header_block->AddPrevBlock(top_block);
// Create hasNext apply.
AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
ParameterPtr iter_param = list_header_block->func_graph()->add_parameter();
constexpr auto iter_param_name = "iter";
iter_param->set_name(iter_param_name);
iter_param->debug_info()->set_name(iter_param_name);
CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
// Call the header graph with iter.
ParameterPtr list_param = list_header_block->func_graph()->add_parameter();
constexpr auto list_param_name = "list";
list_param->set_name(list_param_name);
list_param->debug_info()->set_name(list_param_name);
auto empty_list = std::vector<ValuePtr>();
AnfNodePtr empty_list_node = NewValueNode(std::make_shared<ValueList>(empty_list));
top_block->Jump(list_header_block, {iter_apply, empty_list_node});
// Create body graph.
FunctionBlockPtr list_body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
list_body_block->AddPrevBlock(list_header_block);
AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
CNodePtr item_apply =
list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
CNodePtr new_iter =
list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});
// Save the `target` in a variable.
py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
WriteAssignVars(list_body_block, gen_target_node, item_apply);
auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node);
list_body_block->Jump(list_header_block, {new_iter, ifs_new_list});
// Create after graph.
FunctionBlockPtr list_after_block = GenerateBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
list_after_block->AddPrevBlock(list_header_block);
// Return the list in after graph.
list_after_block->func_graph()->set_output(list_param);
// Run the branches.
list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
top_block->Mature();
list_header_block->Mature();
list_body_block->Mature();
list_after_block->Mature();
return top_block;
}
AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
const py::object &node, const py::object &generator_node) {
// Handle ifs attribute.
py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
AnfNodePtr ifs_bool_node;
if (ifs_node.empty()) {
ifs_bool_node = NewValueNode(true);
} else {
ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
}
// Create if-true graph.
FunctionBlockPtr if_true_block =
GenerateBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
if_true_block->AddPrevBlock(list_body_block);
// Handle elt attribute in body block.
py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
// Append the element.
auto list_append_op = prim::kPrimListAppend;
auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
// Return new list in true branch graph.
if_true_block->func_graph()->set_output(new_list);
// Create if-false graph.
FunctionBlockPtr if_false_block =
GenerateBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
if_false_block->AddPrevBlock(list_body_block);
// Return original list in false branch graph.
if_false_block->func_graph()->set_output(list_param);
// We don't want to create a header graph, where to get and wrap the result of Switch().
// So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
// Output is Switch() result, i.e. updated list.
auto switch_apply_node = list_body_block->func_graph()->output();
auto ifs_new_list = switch_apply_node;
// Since we call ConditionalJump() above, to reset the Return as null before call Jump().
list_body_block->func_graph()->set_return(nullptr);
if_true_block->Mature();
if_false_block->Mature();
return ifs_new_list;
}
// A ListComp contains: `elt` and `generators`.
// `generators` contains: `target`, `iter` and `ifs`.
// For example:
// [x * x for x in range(0, 10) if x % 2 == 0]
// It is compiled to be following statement:
// list = []
// for x in range(0, 10):
// if x % 2 == 0:
// list.append(x * x)
// return list
AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast ListComp";
MS_EXCEPTION_IF_NULL(block);
// Handle generators attribute.
py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
if (generators_node.size() != 1) {
MS_EXCEPTION(TypeError) << "The `generators` supports one `comprehension` in ListComp/GeneratorExp, but got "
<< generators_node.size() << " comprehensions.";
}
py::object generator_node = generators_node[0];
auto generator_node_type = ast_->GetNodeType(generator_node);
auto generator_node_name = generator_node_type->node_name();
constexpr auto comprehension_name = "comprehension";
if (generator_node_name != comprehension_name) {
MS_LOG(EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got " << generator_node_name;
}
// Parse ListComp's `iter` and add `elt` in it.
auto top_block = ParseListCompIter(block, node, generator_node);
// Call the top graph and return the list.
auto call_function_anf_node = NewValueNode(top_block->func_graph());
std::vector<AnfNodePtr> func_call_nodes;
func_call_nodes.push_back(call_function_anf_node);
AnfNodePtr output = block->func_graph()->NewCNodeInOrder(func_call_nodes);
return output;
}
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(assigned_node);
@ -1644,7 +1793,7 @@ FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::obj
loop.end = MakeFunctionBlock(*this);
}
// Jump to the end_block.
block->Jump(loop.end, nullptr);
block->Jump(loop.end, {});
return block;
}
@ -1655,7 +1804,11 @@ FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::
}
// Jump to the header of the loop with iterator called.
Loop &loop = loops_.top();
block->Jump(loop.header, loop.iterator);
std::vector<AnfNodePtr> args;
if (loop.iterator != nullptr) {
args.emplace_back(loop.iterator);
}
block->Jump(loop.header, args);
return block;
}

View File

@ -38,19 +38,19 @@ namespace parse {
// Parse status define
enum ParseStatusCode : int64_t {
PARSE_SUCCESS = 0,
PARSE_FUNCTION_IS_NULL, // python function is null
PARSE_PARAMETER_INVALID, // parameter is invalid
PARSE_NO_RETURN, // function no return node
PARSE_NODE_TYPE_NO_MATCH, // ast node type is error
PARSE_NODE_TYPE_UNKNOWN, // node type is unknown
PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node
PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string
PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported
PARSE_FUNCTION_IS_NULL, // Python function is null
PARSE_PARAMETER_INVALID, // Parameter is invalid
PARSE_NO_RETURN, // Function no return node
PARSE_NODE_TYPE_NO_MATCH, // Ast node type is error
PARSE_NODE_TYPE_UNKNOWN, // Node type is unknown
PARSE_NODE_METHOD_UNSUPPORTED, // No method to parse the node
PARSE_DONT_RESOLVE_SYMBOL, // Can't resolve the string
PARSE_NOT_SUPPORTED_COMPARE_EXPR, // The comparison is not supported
PARSE_FAILURE = 0xFF
};
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
// Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
// will be sunk(i.e. not unrolled)
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
// not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make
// the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised
@ -97,7 +97,7 @@ class Parser {
FuncGraphPtr func_graph() const { return func_graph_; }
ParseStatusCode errcode() const { return errcode_; }
std::shared_ptr<ParseAst> ast() const { return ast_; }
// get location info from the ast node
// Get location info from the ast node
LocationPtr GetLocation(const py::object &node) const;
static void InitParserEnvironment(const py::object &obj);
static void CleanParserResource();
@ -105,114 +105,118 @@ class Parser {
static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
private:
// process the stmt node method list
// Process the stmt node method list
FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
// parse expression
// Parse expression
FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
// process a if statement
// Process a if statement
FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
// process a while statement
// Process a while statement
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
// process a for statement
// Process a for statement
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
// process a function def statement
// Process a function def statement
FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
// process a augment assign
// Process a augment assign
FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
// process a global declaration
// Process a global declaration
FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
// process assign statement
// Process assign statement
FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
// process break statement
// Process break statement
FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
// process continue statement
// Process continue statement
FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
// process pass statement
// Process pass statement
FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
// process the expr and slice node method list
// Process the expr and slice node method list
AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
// process a variable name
// Process a variable name
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
// process NoneType
// Process NoneType
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
// process Ellipsis
// Process Ellipsis
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
// process a integer or float number
// Process a integer or float number
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
// process a string variable
// Process a string variable
AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
// process a Constant
// Process a Constant
AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
// process a name
// Process a name
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
// process a function call
// Process a function call
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
// process function 'super'
// Process function 'super'
AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
// process the if expression
// Process the if expression
AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
// process class type define
// Process class type define
AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
// process a compare expression
// Process a compare expression
AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
// process a bool operation
// Process a bool operation
AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
// process a lambda operation
// Process a lambda operation
AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
// process a tuple
// Process a tuple
AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
// process a tuple
// Process a tuple
AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
// process a tuple
// Process a tuple
AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
// process a slice
// Process a slice
AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
// process a extslice
// Process a extslice
AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
// process a tuple
// Process a tuple
AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
// process a unaryop
// Process a unaryop
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
// process a dict ast node expression
// Process a dict ast node expression
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
// generate argument nodes for ast function node
// Process ListComp expression
AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
const py::object &generator_node);
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
const py::object &node, const py::object &generator_node);
// Generate argument nodes for ast function node
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// generate argument default value for ast function node
// Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// parse ast function node
// Parse ast function node
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// parse ast statements
// Parse ast statements
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
// parse one ast statement node
// Parse one ast statement node
FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
// parse an ast expression node
// Parse an ast expression node
AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
const FunctionBlockPtr &falseBlock);
void RemoveUnnecessaryPhis();
// write a new var
// Write a new var
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
// assign value to single variable name
// Assign value to single variable name
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// assign value to tuple
// Assign value to tuple
void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// assign value to class member
// Assign value to class member
void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// assign value to subscript
// Assign value to subscript
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// process a bool operation value list
// Process a bool operation value list
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
@ -221,7 +225,7 @@ class Parser {
CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
const AnfNodePtr &op_hasnext);
FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr &trace_info);
FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);
bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
std::vector<AnfNodePtr> *packed_arguments);
@ -249,27 +253,27 @@ class Parser {
func_block_list_.push_back(block);
return block;
}
// return a make tuple for input elements list
// Return a make tuple for input elements list
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
int64_t GetForTransToWhileLoop();
// shared_ptr will be hold by GraphManager, so just hold a weak ref here.
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
static FuncGraphWeakPtr top_func_graph_;
// Python function id, used to indicate whether two CNodes come from the same Python function
const std::shared_ptr<ParseAst> &ast_;
FuncGraphPtr func_graph_;
// error code setwhen parsing ast tree
// Error code setwhen parsing ast tree
ParseStatusCode errcode_;
// hold all reference for FunctionBlock in this round of parsing,
// Hold all reference for FunctionBlock in this round of parsing,
// so in FunctionBlock class we can use FunctionBlock* in member
// pre_blocks_ and jumps_ to break reference cycle.
std::vector<FunctionBlockPtr> func_block_list_;
using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
// define the function map to parse ast Statement
// Define the function map to parse ast Statement
std::map<std::string, pStmtFunc> stmt_method_map_;
// define the function map to parse ast expression
// Define the function map to parse ast expression
std::map<std::string, pExprFunc> expr_method_map_;
// Save current loops to support 'continue', 'break' statement.
std::stack<Loop> loops_;
@ -350,10 +354,10 @@ class ParseAst {
bool IsClassMember(const py::object &node);
private:
// save obj,eg: class instance or function
// Save obj,eg: class instance or function
py::object obj_;
// function or class method.
// Function or class method.
py::function function_;
py::object ast_tree_;
@ -369,7 +373,7 @@ class ParseAst {
int64_t function_line_offset_;
};
// update the graph flags
// Update the graph flags
bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);

View File

@ -271,10 +271,14 @@ const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const
std::string AbstractSequeue::ToString() const {
std::ostringstream buffer;
int64_t i = 0;
size_t i = 0;
size_t size = elements_.size();
for (const auto &ele : elements_) {
MS_EXCEPTION_IF_NULL(ele);
buffer << "element[" << i << "]: " << ele->ToString() << ",";
buffer << "element[" << i << "]: " << ele->ToString();
if (i < size - 1) {
buffer << ", ";
}
i++;
}
return buffer.str();

View File

@ -318,8 +318,11 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
(void)AbstractJoin(list->elements());
return list;
AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]);
MS_EXCEPTION_IF_NULL(item);
auto new_list = AbstractBasePtrList(list->elements());
new_list.emplace_back(item);
return std::make_shared<AbstractList>(new_list);
}
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -430,6 +430,14 @@ class TraceOpt : public TraceInfo {
~TraceOpt() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceOpt>(*shared_from_base<TraceOpt>()); }
};
class TraceListComp : public TraceInfo {
public:
explicit TraceListComp(const DebugInfoPtr &info) : TraceInfo(info, "ListComp", "G-") {}
MS_DECLARE_PARENT(TraceListComp, TraceInfo);
~TraceListComp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceListComp>(*shared_from_base<TraceListComp>()); }
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ListComp and GeneratorExp """
import pytest
from mindspore import context, ms_function
@ms_function
def get_list_comp_1():
l = [x for x in range(1, 6)]
return l
@ms_function
def get_list_comp_2():
l = [x * x for x in range(1, 6)]
return l
@ms_function
def get_list_comp_3():
l = [x * x for x in range(1, 11) if x % 2 == 0]
return l
@ms_function
def get_list_comp_4():
l = [x * x for x in range(1, 11) if x > 5 if x % 2 == 0]
return l
@ms_function
def get_list_comp_5():
# Create a ListComp with multiple comprehension.
# Not supported.
l = [y for x in ((1, 2), (3, 4), (5, 6)) for y in x] # [1, 2, 3, 4, 5, 6]
return l
@ms_function
def get_generator_exp_1():
t = (x for x in range(1, 6))
return t
@ms_function
def get_generator_exp_2():
t = (x * x for x in range(1, 11) if x > 5 if x % 2 == 0)
return t
def test_list_comp():
context.set_context(mode=context.GRAPH_MODE)
assert get_list_comp_1() == (1, 2, 3, 4, 5)
assert get_list_comp_2() == (1, 4, 9, 16, 25)
assert get_list_comp_3() == (4, 16, 36, 64, 100)
assert get_list_comp_4() == (36, 64, 100)
with pytest.raises(TypeError) as ex:
get_list_comp_5()
assert "The `generators` supports one `comprehension` in ListComp/GeneratorExp" in str(ex.value)
assert get_generator_exp_1() == (1, 2, 3, 4, 5)
assert get_generator_exp_2() == (36, 64, 100)