forked from mindspore-Ecosystem/mindspore
Support ListComp and GeneratorExp in Graph Mode.
This commit is contained in:
parent
dfd3f92858
commit
aef396f9f2
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
|
@ -329,10 +330,10 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
|||
|
||||
// A block should be marked matured if its predecessor blocks have been processed
|
||||
void FunctionBlock::Mature() {
|
||||
const auto &graphParamVec = func_graph_->parameters();
|
||||
for (auto ¶mItr : graphParamVec) {
|
||||
MS_EXCEPTION_IF_NULL(paramItr);
|
||||
auto param = paramItr->cast<ParameterPtr>();
|
||||
const auto &graph_params = func_graph_->parameters();
|
||||
for (auto ¶m_itr : graph_params) {
|
||||
MS_EXCEPTION_IF_NULL(param_itr);
|
||||
auto param = param_itr->cast<ParameterPtr>();
|
||||
if (phi_nodes_.find(param) != phi_nodes_.cend()) {
|
||||
SetPhiArgument(param);
|
||||
}
|
||||
|
@ -356,7 +357,7 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
|
|||
}
|
||||
|
||||
// Perform a jump from this block to target block
|
||||
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr &node) {
|
||||
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
|
||||
MS_EXCEPTION_IF_NULL(target_block);
|
||||
if (func_graph_->get_return() != nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
|
||||
|
@ -364,9 +365,7 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr
|
|||
}
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
|
||||
if (node != nullptr) {
|
||||
input_nodes.emplace_back(node);
|
||||
}
|
||||
(void)std::copy(args.begin(), args.end(), std::back_inserter(input_nodes));
|
||||
|
||||
CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
|
||||
jumps_[target_block.get()] = jump;
|
||||
|
|
|
@ -57,7 +57,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
void Mature();
|
||||
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
|
||||
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
|
||||
void Jump(const FunctionBlockPtr &block, const AnfNodePtr &node);
|
||||
void Jump(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &args);
|
||||
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
|
||||
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
|
||||
bool unroll_loop = true);
|
||||
|
|
|
@ -130,6 +130,8 @@ void Parser::BuildMethodMap() {
|
|||
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
|
||||
expr_method_map_["Dict"] = &Parser::ParseDict;
|
||||
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
|
||||
expr_method_map_["ListComp"] = &Parser::ParseListComp;
|
||||
expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
|
||||
}
|
||||
|
||||
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
|
||||
|
@ -156,8 +158,8 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
|
|||
}
|
||||
py::object node = ast->GetAstNode();
|
||||
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
constexpr auto kMinListSize = 2;
|
||||
if (ret.size() < kMinListSize) {
|
||||
constexpr auto min_list_size = 2;
|
||||
if (ret.size() < min_list_size) {
|
||||
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
|
||||
}
|
||||
py::str desc =
|
||||
|
@ -169,18 +171,15 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
|
|||
FuncGraphPtr Parser::ParseFuncGraph() {
|
||||
// Get ast FunctionDef node
|
||||
py::object node = ast_->GetAstNode();
|
||||
FunctionBlockPtr pFnBlock = ParseFunction(node);
|
||||
FunctionBlockPtr fn_block = ParseFunction(node);
|
||||
if (errcode() != PARSE_SUCCESS) {
|
||||
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
RemoveUnnecessaryPhis();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(pFnBlock);
|
||||
CheckFuncReturn(pFnBlock->func_graph(), ast_);
|
||||
|
||||
return pFnBlock->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fn_block);
|
||||
CheckFuncReturn(fn_block->func_graph(), ast_);
|
||||
return fn_block->func_graph();
|
||||
}
|
||||
|
||||
void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
|
||||
|
@ -261,14 +260,14 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
// The node created in the parsefunction context, will inherit the scope created using scope_guard
|
||||
ScopeGuard scope_guard(scope);
|
||||
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
|
||||
FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
|
||||
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
|
||||
if (block != nullptr) {
|
||||
pFunBlock->AddPrevBlock(block);
|
||||
func_block->AddPrevBlock(block);
|
||||
} else {
|
||||
func_graph_ = pFunBlock->func_graph();
|
||||
func_graph_ = func_block->func_graph();
|
||||
}
|
||||
pFunBlock->Mature();
|
||||
auto current_fg = pFunBlock->func_graph();
|
||||
func_block->Mature();
|
||||
auto current_fg = func_block->func_graph();
|
||||
auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
|
||||
MS_LOG(DEBUG) << "The function name is " << function_name;
|
||||
current_fg->debug_info()->set_name(function_name);
|
||||
|
@ -286,27 +285,27 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
MS_LOG(ERROR) << "Set flags failed";
|
||||
return nullptr;
|
||||
}
|
||||
GenerateArgsNodeForFunction(pFunBlock, node);
|
||||
GenerateArgsNodeForFunction(func_block, node);
|
||||
|
||||
// When parsing the top graph of construct, save the top graph
|
||||
if (GetTopFuncGraph() == nullptr) {
|
||||
UpdateTopFuncGraph(pFunBlock->func_graph());
|
||||
UpdateTopFuncGraph(func_block->func_graph());
|
||||
}
|
||||
|
||||
// Save the function node to block
|
||||
pFunBlock->WriteVariable(function_name, NewValueNode(current_fg));
|
||||
func_block->WriteVariable(function_name, NewValueNode(current_fg));
|
||||
|
||||
py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
|
||||
(void)ParseStatements(pFunBlock, funcObj);
|
||||
(void)ParseStatements(func_block, funcObj);
|
||||
|
||||
// Add unused variables as isolate nodes.
|
||||
for (auto &func_block : func_block_list_) {
|
||||
MS_EXCEPTION_IF_NULL(func_block);
|
||||
if (func_block->func_graph()->get_return() != nullptr) {
|
||||
for (auto &func_block_item : func_block_list_) {
|
||||
MS_EXCEPTION_IF_NULL(func_block_item);
|
||||
if (func_block_item->func_graph()->get_return() != nullptr) {
|
||||
// Find unused variables.
|
||||
func_block->FindIsolatedNodes();
|
||||
func_block_item->FindIsolatedNodes();
|
||||
// Attach all isolated nodes.
|
||||
func_block->AttachIsolatedNodesBeforeReturn();
|
||||
func_block_item->AttachIsolatedNodesBeforeReturn();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -315,8 +314,8 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
}
|
||||
GenerateArgsDefaultValueForFunction(pFunBlock, node);
|
||||
return pFunBlock;
|
||||
GenerateArgsDefaultValueForFunction(func_block, node);
|
||||
return func_block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
|
||||
|
@ -461,14 +460,14 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
|
|||
MS_LOG(DEBUG) << "Process ast return";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// Create return valuenode
|
||||
AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
|
||||
AnfNodePtr return_value_node = NewValueNode(prim::kPrimReturn);
|
||||
// Parse the return Statements value
|
||||
py::object value = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
|
||||
AnfNodePtr return_expr_node = ParseExprNode(block, value);
|
||||
// Create the cnode
|
||||
auto block_fg = block->func_graph();
|
||||
CNodePtr pReturnCNode = block_fg->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
|
||||
block_fg->set_return(pReturnCNode);
|
||||
CNodePtr return_node = block_fg->NewCNodeInOrder({return_value_node, return_expr_node});
|
||||
block_fg->set_return(return_node);
|
||||
return block;
|
||||
}
|
||||
|
||||
|
@ -583,6 +582,7 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
|
|||
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
|
||||
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
|
||||
|
@ -1117,18 +1117,18 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
|
||||
|
||||
// If the return_ is set ,it has its own continuation block
|
||||
// If the return_ is set, it has its own continuation block
|
||||
if (true_end->func_graph()->get_return() == nullptr) {
|
||||
true_end->Jump(after_block, nullptr);
|
||||
true_end->Jump(after_block, {});
|
||||
}
|
||||
|
||||
// Process the orelse branch
|
||||
py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
|
||||
FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
|
||||
|
||||
// If the return_ is set ,it has its own continuation block
|
||||
// If the return_ is set, it has its own continuation block
|
||||
if (false_end->func_graph()->get_return() == nullptr) {
|
||||
false_end->Jump(after_block, nullptr);
|
||||
false_end->Jump(after_block, {});
|
||||
}
|
||||
|
||||
block->ConditionalJump(bool_node, true_block, false_block);
|
||||
|
@ -1158,7 +1158,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
|
|||
|
||||
body_block->AddPrevBlock(header_block);
|
||||
after_block->AddPrevBlock(header_block);
|
||||
block->Jump(header_block, nullptr);
|
||||
block->Jump(header_block, {});
|
||||
|
||||
py::object test_node = python_adapter::GetPyObjAttr(node, "test");
|
||||
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
|
||||
|
@ -1171,7 +1171,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
|
|||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
|
||||
if (after_body->func_graph()->get_return() == nullptr) {
|
||||
after_body->Jump(header_block, nullptr);
|
||||
after_body->Jump(header_block, {});
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
|
@ -1179,7 +1179,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
|
|||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
after_block->Jump(end_block, {});
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
|
@ -1200,16 +1200,17 @@ CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const Functio
|
|||
return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
|
||||
FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
|
||||
TraceGuard trace_guard(trace_info);
|
||||
FunctionBlockPtr body_block = MakeFunctionBlock(*this);
|
||||
return body_block;
|
||||
FunctionBlockPtr block = MakeFunctionBlock(*this);
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
return block;
|
||||
}
|
||||
|
||||
int64_t Parser::GetForTransToWhileLoop() {
|
||||
// int64 support 63bits positive num mostly.
|
||||
constexpr auto kMaxNumLength = 10;
|
||||
if (max_for_loop_count_str_.size() > kMaxNumLength || max_for_loop_count_str_.empty()) {
|
||||
constexpr auto max_num_length = 10;
|
||||
if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
|
||||
return MAX_FOR_LOOP_COUNT;
|
||||
}
|
||||
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
|
||||
|
@ -1222,6 +1223,7 @@ int64_t Parser::GetForTransToWhileLoop() {
|
|||
ss >> loop_count;
|
||||
return loop_count;
|
||||
}
|
||||
|
||||
// A for loop will generate 3 functions :the test, the body, and the continuation
|
||||
// for x in xs:
|
||||
// body
|
||||
|
@ -1260,10 +1262,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
|
|||
}
|
||||
|
||||
FunctionBlockPtr true_end = ParseForIter(true_block, node);
|
||||
true_end->Jump(after_block, nullptr);
|
||||
true_end->Jump(after_block, {});
|
||||
|
||||
FunctionBlockPtr false_end = ParseForLoop(false_block, node);
|
||||
false_end->Jump(after_block, nullptr);
|
||||
false_end->Jump(after_block, {});
|
||||
|
||||
block->ConditionalJump(bool_node, true_block, false_block);
|
||||
after_block->Mature();
|
||||
|
@ -1288,14 +1290,13 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
|
|||
// Generate the iterator apply
|
||||
CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
|
||||
MS_EXCEPTION_IF_NULL(iter_apply);
|
||||
FunctionBlockPtr header_block =
|
||||
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(header_block);
|
||||
// Generate the hasnext apply which is a condition
|
||||
ParameterPtr iter_param = header_block->func_graph()->add_parameter();
|
||||
CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
|
||||
// Generate the body of the for statement
|
||||
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(body_block);
|
||||
body_block->AddPrevBlock(header_block);
|
||||
// Generate the iterator next apply
|
||||
|
@ -1323,7 +1324,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
|
|||
MS_EXCEPTION_IF_NULL(after_block);
|
||||
after_block->AddPrevBlock(header_block);
|
||||
|
||||
block->Jump(header_block, iter_apply);
|
||||
block->Jump(header_block, {iter_apply});
|
||||
body_block->Mature();
|
||||
header_block->ConditionalJump(cond_apply, body_block, after_block);
|
||||
|
||||
|
@ -1332,7 +1333,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
|
|||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
||||
after_body_block->Jump(header_block, iter2_app);
|
||||
after_body_block->Jump(header_block, {iter2_app});
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
|
@ -1340,7 +1341,7 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o
|
|||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
after_block->Jump(end_block, {});
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
|
@ -1377,8 +1378,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
|
||||
CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});
|
||||
|
||||
FunctionBlockPtr header_block =
|
||||
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(header_block);
|
||||
// Create loop variable 'i'
|
||||
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
|
||||
|
@ -1388,7 +1388,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});
|
||||
|
||||
// Generate the body of the for statement
|
||||
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
MS_EXCEPTION_IF_NULL(body_block);
|
||||
body_block->AddPrevBlock(header_block);
|
||||
// Create 'x = xs[i]'
|
||||
|
@ -1419,7 +1419,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
|
||||
CNodePtr zero_tensor =
|
||||
block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
|
||||
block->Jump(header_block, zero_tensor);
|
||||
block->Jump(header_block, {zero_tensor});
|
||||
body_block->Mature();
|
||||
|
||||
header_block->ConditionalJump(cond_node, body_block, after_block, false);
|
||||
|
@ -1429,7 +1429,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
||||
after_body_block->Jump(header_block, loop_var_inc);
|
||||
after_body_block->Jump(header_block, {loop_var_inc});
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
|
@ -1437,7 +1437,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
after_block->Jump(end_block, {});
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
|
@ -1489,6 +1489,155 @@ AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &n
|
|||
return switch_app_call;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
|
||||
const py::object &generator_node) {
|
||||
// Create a header block.
|
||||
FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
|
||||
// Handle iter attribute.
|
||||
py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
|
||||
AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
|
||||
AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
||||
CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
|
||||
|
||||
// Create header graph.
|
||||
FunctionBlockPtr list_header_block =
|
||||
GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
||||
list_header_block->AddPrevBlock(top_block);
|
||||
|
||||
// Create hasNext apply.
|
||||
AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
|
||||
ParameterPtr iter_param = list_header_block->func_graph()->add_parameter();
|
||||
constexpr auto iter_param_name = "iter";
|
||||
iter_param->set_name(iter_param_name);
|
||||
iter_param->debug_info()->set_name(iter_param_name);
|
||||
CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
|
||||
|
||||
// Call the header graph with iter.
|
||||
ParameterPtr list_param = list_header_block->func_graph()->add_parameter();
|
||||
constexpr auto list_param_name = "list";
|
||||
list_param->set_name(list_param_name);
|
||||
list_param->debug_info()->set_name(list_param_name);
|
||||
auto empty_list = std::vector<ValuePtr>();
|
||||
AnfNodePtr empty_list_node = NewValueNode(std::make_shared<ValueList>(empty_list));
|
||||
top_block->Jump(list_header_block, {iter_apply, empty_list_node});
|
||||
|
||||
// Create body graph.
|
||||
FunctionBlockPtr list_body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
||||
list_body_block->AddPrevBlock(list_header_block);
|
||||
AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
|
||||
CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
|
||||
AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
CNodePtr item_apply =
|
||||
list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
|
||||
CNodePtr new_iter =
|
||||
list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});
|
||||
|
||||
// Save the `target` in a variable.
|
||||
py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
|
||||
WriteAssignVars(list_body_block, gen_target_node, item_apply);
|
||||
|
||||
auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node);
|
||||
list_body_block->Jump(list_header_block, {new_iter, ifs_new_list});
|
||||
|
||||
// Create after graph.
|
||||
FunctionBlockPtr list_after_block = GenerateBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
|
||||
list_after_block->AddPrevBlock(list_header_block);
|
||||
// Return the list in after graph.
|
||||
list_after_block->func_graph()->set_output(list_param);
|
||||
|
||||
// Run the branches.
|
||||
list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);
|
||||
|
||||
top_block->Mature();
|
||||
list_header_block->Mature();
|
||||
list_body_block->Mature();
|
||||
list_after_block->Mature();
|
||||
return top_block;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
|
||||
const py::object &node, const py::object &generator_node) {
|
||||
// Handle ifs attribute.
|
||||
py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
|
||||
AnfNodePtr ifs_bool_node;
|
||||
if (ifs_node.empty()) {
|
||||
ifs_bool_node = NewValueNode(true);
|
||||
} else {
|
||||
ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
|
||||
}
|
||||
|
||||
// Create if-true graph.
|
||||
FunctionBlockPtr if_true_block =
|
||||
GenerateBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
|
||||
if_true_block->AddPrevBlock(list_body_block);
|
||||
// Handle elt attribute in body block.
|
||||
py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
|
||||
AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
|
||||
// Append the element.
|
||||
auto list_append_op = prim::kPrimListAppend;
|
||||
auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
|
||||
// Return new list in true branch graph.
|
||||
if_true_block->func_graph()->set_output(new_list);
|
||||
|
||||
// Create if-false graph.
|
||||
FunctionBlockPtr if_false_block =
|
||||
GenerateBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
|
||||
if_false_block->AddPrevBlock(list_body_block);
|
||||
// Return original list in false branch graph.
|
||||
if_false_block->func_graph()->set_output(list_param);
|
||||
|
||||
// We don't want to create a header graph, where to get and wrap the result of Switch().
|
||||
// So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
|
||||
list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
|
||||
// Output is Switch() result, i.e. updated list.
|
||||
auto switch_apply_node = list_body_block->func_graph()->output();
|
||||
auto ifs_new_list = switch_apply_node;
|
||||
// Since we call ConditionalJump() above, to reset the Return as null before call Jump().
|
||||
list_body_block->func_graph()->set_return(nullptr);
|
||||
if_true_block->Mature();
|
||||
if_false_block->Mature();
|
||||
return ifs_new_list;
|
||||
}
|
||||
|
||||
// A ListComp contains: `elt` and `generators`.
|
||||
// `generators` contains: `target`, `iter` and `ifs`.
|
||||
// For example:
|
||||
// [x * x for x in range(0, 10) if x % 2 == 0]
|
||||
// It is compiled to be following statement:
|
||||
// list = []
|
||||
// for x in range(0, 10):
|
||||
// if x % 2 == 0:
|
||||
// list.append(x * x)
|
||||
// return list
|
||||
AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast ListComp";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
|
||||
// Handle generators attribute.
|
||||
py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
|
||||
if (generators_node.size() != 1) {
|
||||
MS_EXCEPTION(TypeError) << "The `generators` supports one `comprehension` in ListComp/GeneratorExp, but got "
|
||||
<< generators_node.size() << " comprehensions.";
|
||||
}
|
||||
py::object generator_node = generators_node[0];
|
||||
auto generator_node_type = ast_->GetNodeType(generator_node);
|
||||
auto generator_node_name = generator_node_type->node_name();
|
||||
constexpr auto comprehension_name = "comprehension";
|
||||
if (generator_node_name != comprehension_name) {
|
||||
MS_LOG(EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got " << generator_node_name;
|
||||
}
|
||||
|
||||
// Parse ListComp's `iter` and add `elt` in it.
|
||||
auto top_block = ParseListCompIter(block, node, generator_node);
|
||||
|
||||
// Call the top graph and return the list.
|
||||
auto call_function_anf_node = NewValueNode(top_block->func_graph());
|
||||
std::vector<AnfNodePtr> func_call_nodes;
|
||||
func_call_nodes.push_back(call_function_anf_node);
|
||||
AnfNodePtr output = block->func_graph()->NewCNodeInOrder(func_call_nodes);
|
||||
return output;
|
||||
}
|
||||
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
MS_EXCEPTION_IF_NULL(assigned_node);
|
||||
|
@ -1644,7 +1793,7 @@ FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::obj
|
|||
loop.end = MakeFunctionBlock(*this);
|
||||
}
|
||||
// Jump to the end_block.
|
||||
block->Jump(loop.end, nullptr);
|
||||
block->Jump(loop.end, {});
|
||||
return block;
|
||||
}
|
||||
|
||||
|
@ -1655,7 +1804,11 @@ FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::
|
|||
}
|
||||
// Jump to the header of the loop with iterator called.
|
||||
Loop &loop = loops_.top();
|
||||
block->Jump(loop.header, loop.iterator);
|
||||
std::vector<AnfNodePtr> args;
|
||||
if (loop.iterator != nullptr) {
|
||||
args.emplace_back(loop.iterator);
|
||||
}
|
||||
block->Jump(loop.header, args);
|
||||
return block;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,19 +38,19 @@ namespace parse {
|
|||
// Parse status define
|
||||
enum ParseStatusCode : int64_t {
|
||||
PARSE_SUCCESS = 0,
|
||||
PARSE_FUNCTION_IS_NULL, // python function is null
|
||||
PARSE_PARAMETER_INVALID, // parameter is invalid
|
||||
PARSE_NO_RETURN, // function no return node
|
||||
PARSE_NODE_TYPE_NO_MATCH, // ast node type is error
|
||||
PARSE_NODE_TYPE_UNKNOWN, // node type is unknown
|
||||
PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node
|
||||
PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string
|
||||
PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported
|
||||
PARSE_FUNCTION_IS_NULL, // Python function is null
|
||||
PARSE_PARAMETER_INVALID, // Parameter is invalid
|
||||
PARSE_NO_RETURN, // Function no return node
|
||||
PARSE_NODE_TYPE_NO_MATCH, // Ast node type is error
|
||||
PARSE_NODE_TYPE_UNKNOWN, // Node type is unknown
|
||||
PARSE_NODE_METHOD_UNSUPPORTED, // No method to parse the node
|
||||
PARSE_DONT_RESOLVE_SYMBOL, // Can't resolve the string
|
||||
PARSE_NOT_SUPPORTED_COMPARE_EXPR, // The comparison is not supported
|
||||
PARSE_FAILURE = 0xFF
|
||||
};
|
||||
|
||||
// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
// Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it
|
||||
// will be sunk(i.e. not unrolled)
|
||||
// NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were
|
||||
// not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make
|
||||
// the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised
|
||||
|
@ -97,7 +97,7 @@ class Parser {
|
|||
FuncGraphPtr func_graph() const { return func_graph_; }
|
||||
ParseStatusCode errcode() const { return errcode_; }
|
||||
std::shared_ptr<ParseAst> ast() const { return ast_; }
|
||||
// get location info from the ast node
|
||||
// Get location info from the ast node
|
||||
LocationPtr GetLocation(const py::object &node) const;
|
||||
static void InitParserEnvironment(const py::object &obj);
|
||||
static void CleanParserResource();
|
||||
|
@ -105,114 +105,118 @@ class Parser {
|
|||
static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
// process the stmt node method list
|
||||
// Process the stmt node method list
|
||||
FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
|
||||
// parse expression
|
||||
// Parse expression
|
||||
FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a if statement
|
||||
// Process a if statement
|
||||
FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a while statement
|
||||
// Process a while statement
|
||||
FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a for statement
|
||||
// Process a for statement
|
||||
FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a function def statement
|
||||
// Process a function def statement
|
||||
FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a augment assign
|
||||
// Process a augment assign
|
||||
FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a global declaration
|
||||
// Process a global declaration
|
||||
FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process assign statement
|
||||
// Process assign statement
|
||||
FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process break statement
|
||||
// Process break statement
|
||||
FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process continue statement
|
||||
// Process continue statement
|
||||
FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process pass statement
|
||||
// Process pass statement
|
||||
FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process the expr and slice node method list
|
||||
|
||||
// Process the expr and slice node method list
|
||||
AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a variable name
|
||||
// Process a variable name
|
||||
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process NoneType
|
||||
// Process NoneType
|
||||
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process Ellipsis
|
||||
// Process Ellipsis
|
||||
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a integer or float number
|
||||
// Process a integer or float number
|
||||
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a string variable
|
||||
// Process a string variable
|
||||
AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a Constant
|
||||
// Process a Constant
|
||||
AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a name
|
||||
// Process a name
|
||||
AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a function call
|
||||
// Process a function call
|
||||
AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process function 'super'
|
||||
// Process function 'super'
|
||||
AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
|
||||
// process the if expression
|
||||
// Process the if expression
|
||||
AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process class type define
|
||||
// Process class type define
|
||||
AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a compare expression
|
||||
// Process a compare expression
|
||||
AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a bool operation
|
||||
// Process a bool operation
|
||||
AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a lambda operation
|
||||
// Process a lambda operation
|
||||
AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a tuple
|
||||
// Process a tuple
|
||||
AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a tuple
|
||||
// Process a tuple
|
||||
AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a tuple
|
||||
// Process a tuple
|
||||
AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a slice
|
||||
// Process a slice
|
||||
AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
// process a extslice
|
||||
// Process a extslice
|
||||
AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
// process a tuple
|
||||
// Process a tuple
|
||||
AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
// process a unaryop
|
||||
// Process a unaryop
|
||||
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
// process a dict ast node expression
|
||||
// Process a dict ast node expression
|
||||
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
|
||||
// generate argument nodes for ast function node
|
||||
// Process ListComp expression
|
||||
AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
|
||||
const py::object &generator_node);
|
||||
AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
|
||||
const py::object &node, const py::object &generator_node);
|
||||
|
||||
// Generate argument nodes for ast function node
|
||||
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// generate argument default value for ast function node
|
||||
// Generate argument default value for ast function node
|
||||
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
|
||||
// parse ast function node
|
||||
// Parse ast function node
|
||||
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
|
||||
// parse ast statements
|
||||
// Parse ast statements
|
||||
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
|
||||
// parse one ast statement node
|
||||
// Parse one ast statement node
|
||||
FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
|
||||
// parse an ast expression node
|
||||
// Parse an ast expression node
|
||||
AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
|
||||
const FunctionBlockPtr &falseBlock);
|
||||
void RemoveUnnecessaryPhis();
|
||||
// write a new var
|
||||
// Write a new var
|
||||
void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
|
||||
|
||||
// assign value to single variable name
|
||||
// Assign value to single variable name
|
||||
void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// assign value to tuple
|
||||
// Assign value to tuple
|
||||
void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// assign value to class member
|
||||
// Assign value to class member
|
||||
void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// assign value to subscript
|
||||
// Assign value to subscript
|
||||
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// process a bool operation value list
|
||||
// Process a bool operation value list
|
||||
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
|
||||
|
||||
CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
|
||||
|
@ -221,7 +225,7 @@ class Parser {
|
|||
CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
|
||||
const AnfNodePtr &op_hasnext);
|
||||
|
||||
FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr &trace_info);
|
||||
FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);
|
||||
|
||||
bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
|
||||
std::vector<AnfNodePtr> *packed_arguments);
|
||||
|
@ -249,27 +253,27 @@ class Parser {
|
|||
func_block_list_.push_back(block);
|
||||
return block;
|
||||
}
|
||||
// return a make tuple for input elements list
|
||||
// Return a make tuple for input elements list
|
||||
AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
|
||||
int64_t GetForTransToWhileLoop();
|
||||
|
||||
// shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
||||
// The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
|
||||
static FuncGraphWeakPtr top_func_graph_;
|
||||
// Python function id, used to indicate whether two CNodes come from the same Python function
|
||||
const std::shared_ptr<ParseAst> &ast_;
|
||||
FuncGraphPtr func_graph_;
|
||||
// error code setwhen parsing ast tree
|
||||
// Error code setwhen parsing ast tree
|
||||
ParseStatusCode errcode_;
|
||||
|
||||
// hold all reference for FunctionBlock in this round of parsing,
|
||||
// Hold all reference for FunctionBlock in this round of parsing,
|
||||
// so in FunctionBlock class we can use FunctionBlock* in member
|
||||
// pre_blocks_ and jumps_ to break reference cycle.
|
||||
std::vector<FunctionBlockPtr> func_block_list_;
|
||||
using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
// define the function map to parse ast Statement
|
||||
// Define the function map to parse ast Statement
|
||||
std::map<std::string, pStmtFunc> stmt_method_map_;
|
||||
// define the function map to parse ast expression
|
||||
// Define the function map to parse ast expression
|
||||
std::map<std::string, pExprFunc> expr_method_map_;
|
||||
// Save current loops to support 'continue', 'break' statement.
|
||||
std::stack<Loop> loops_;
|
||||
|
@ -350,10 +354,10 @@ class ParseAst {
|
|||
bool IsClassMember(const py::object &node);
|
||||
|
||||
private:
|
||||
// save obj,eg: class instance or function
|
||||
// Save obj,eg: class instance or function
|
||||
py::object obj_;
|
||||
|
||||
// function or class method.
|
||||
// Function or class method.
|
||||
py::function function_;
|
||||
|
||||
py::object ast_tree_;
|
||||
|
@ -369,7 +373,7 @@ class ParseAst {
|
|||
int64_t function_line_offset_;
|
||||
};
|
||||
|
||||
// update the graph flags
|
||||
// Update the graph flags
|
||||
bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph);
|
||||
|
||||
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m);
|
||||
|
|
|
@ -271,10 +271,14 @@ const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const
|
|||
|
||||
std::string AbstractSequeue::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
int64_t i = 0;
|
||||
size_t i = 0;
|
||||
size_t size = elements_.size();
|
||||
for (const auto &ele : elements_) {
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
buffer << "element[" << i << "]: " << ele->ToString() << ",";
|
||||
buffer << "element[" << i << "]: " << ele->ToString();
|
||||
if (i < size - 1) {
|
||||
buffer << ", ";
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return buffer.str();
|
||||
|
|
|
@ -318,8 +318,11 @@ AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||
(void)AbstractJoin(list->elements());
|
||||
return list;
|
||||
AbstractBasePtr item = dyn_cast<AbstractBase>(args_spec_list[1]);
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
auto new_list = AbstractBasePtrList(list->elements());
|
||||
new_list.emplace_back(item);
|
||||
return std::make_shared<AbstractList>(new_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -430,6 +430,14 @@ class TraceOpt : public TraceInfo {
|
|||
~TraceOpt() override = default;
|
||||
TraceInfoPtr clone() override { return std::make_shared<TraceOpt>(*shared_from_base<TraceOpt>()); }
|
||||
};
|
||||
|
||||
class TraceListComp : public TraceInfo {
|
||||
public:
|
||||
explicit TraceListComp(const DebugInfoPtr &info) : TraceInfo(info, "ListComp", "G-") {}
|
||||
MS_DECLARE_PARENT(TraceListComp, TraceInfo);
|
||||
~TraceListComp() override = default;
|
||||
TraceInfoPtr clone() override { return std::make_shared<TraceListComp>(*shared_from_base<TraceListComp>()); }
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test ListComp and GeneratorExp """
|
||||
import pytest
|
||||
|
||||
from mindspore import context, ms_function
|
||||
|
||||
@ms_function
|
||||
def get_list_comp_1():
|
||||
l = [x for x in range(1, 6)]
|
||||
return l
|
||||
|
||||
@ms_function
|
||||
def get_list_comp_2():
|
||||
l = [x * x for x in range(1, 6)]
|
||||
return l
|
||||
|
||||
@ms_function
|
||||
def get_list_comp_3():
|
||||
l = [x * x for x in range(1, 11) if x % 2 == 0]
|
||||
return l
|
||||
|
||||
@ms_function
|
||||
def get_list_comp_4():
|
||||
l = [x * x for x in range(1, 11) if x > 5 if x % 2 == 0]
|
||||
return l
|
||||
|
||||
@ms_function
|
||||
def get_list_comp_5():
|
||||
# Create a ListComp with multiple comprehension.
|
||||
# Not supported.
|
||||
l = [y for x in ((1, 2), (3, 4), (5, 6)) for y in x] # [1, 2, 3, 4, 5, 6]
|
||||
return l
|
||||
|
||||
@ms_function
|
||||
def get_generator_exp_1():
|
||||
t = (x for x in range(1, 6))
|
||||
return t
|
||||
|
||||
@ms_function
|
||||
def get_generator_exp_2():
|
||||
t = (x * x for x in range(1, 11) if x > 5 if x % 2 == 0)
|
||||
return t
|
||||
|
||||
def test_list_comp():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert get_list_comp_1() == (1, 2, 3, 4, 5)
|
||||
assert get_list_comp_2() == (1, 4, 9, 16, 25)
|
||||
assert get_list_comp_3() == (4, 16, 36, 64, 100)
|
||||
assert get_list_comp_4() == (36, 64, 100)
|
||||
with pytest.raises(TypeError) as ex:
|
||||
get_list_comp_5()
|
||||
assert "The `generators` supports one `comprehension` in ListComp/GeneratorExp" in str(ex.value)
|
||||
assert get_generator_exp_1() == (1, 2, 3, 4, 5)
|
||||
assert get_generator_exp_2() == (36, 64, 100)
|
Loading…
Reference in New Issue