code self check

This commit is contained in:
chenfei 2021-07-28 10:23:36 +08:00
parent e187cfc889
commit 8e5a250c21
22 changed files with 485 additions and 234 deletions

View File

@ -153,6 +153,7 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(res);
MS_LOG(DEBUG) << "ProgramSpecialize start";
abstract::ProgramSpecializer spc(res->engine());
FuncGraphPtr result = spc.Run(func_graph, context);
@ -165,6 +166,7 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec) {
MS_EXCEPTION_IF_NULL(res);
MS_LOG(DEBUG) << "Renormalize start";
#ifdef ENABLE_PROFILE
double t1 = GetTime();
@ -250,6 +252,7 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
}
bool ParseAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (!res->input()) {
MS_LOG(EXCEPTION) << "Parse error";
}
@ -293,8 +296,8 @@ bool ParseAction(const ResourcePtr &res) {
// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
// all obj_map's graph shared base_graph
bool CombineLikeGraphs(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto &obj_map = parse::data_converter::GetObjGraphs();
for (auto it : obj_map) {
auto &graphs = it.second;
MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
@ -313,6 +316,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
for (auto &fv : fg->paramter_obj_nodes()) {
TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
auto param = base_graph->add_parameter();
MS_EXCEPTION_IF_NULL(res->manager());
auto &node_users = res->manager()->node_users()[fv];
for (auto &n : node_users) {
// If the user is not in this graph, no need to change.
@ -321,6 +325,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
continue;
}
auto repl_n = cloned->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(repl_n);
repl_n->set_input(IntToSize(n.second), param);
}
}
@ -346,6 +351,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
}
bool SymbolResolveAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
}
@ -367,6 +373,7 @@ bool SymbolResolveAction(const ResourcePtr &res) {
}
bool AutoMonadAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "Auto-Monad failed, manager is null";
}
@ -379,6 +386,7 @@ bool AutoMonadAction(const ResourcePtr &res) {
}
bool OrderEnforceAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
}
@ -391,6 +399,7 @@ bool OrderEnforceAction(const ResourcePtr &res) {
}
bool RemoveRandomOpMonadAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "Remove-Random-Op-Monad error, manager is null";
}
@ -403,6 +412,7 @@ bool RemoveRandomOpMonadAction(const ResourcePtr &res) {
}
bool InferenceOptPrepareAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
}
@ -413,6 +423,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
}
bool AbstractSpecializeAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
}
@ -428,8 +439,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->has_default()) {
auto value = param_node->default_param();
MS_EXCEPTION_IF_NULL(value);
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
auto ref_key = std::make_shared<RefKey>(param_node->name());
auto abs_ref_key = ref_key->ToAbstract();
@ -466,6 +479,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
MS_EXCEPTION_IF_NULL(res);
size_t counter = 0;
for (auto &pass : passes) {
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
@ -513,16 +527,8 @@ bool VmOptimizeAction(const ResourcePtr &res) {
return OptimizeAction(res, kVmPasses);
}
bool PynativeOptimizeAction(const ResourcePtr &resource) {
WITH(MsProfile::GetProfile())[&resource]() { (void)OptimizeAction(resource, kPynativePasses); };
#ifdef ENABLE_PROFILE
MsProfile::Print();
MsProfile::Reset();
#endif
return true;
}
bool PynativeElimOpt(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
}
@ -564,6 +570,7 @@ bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
}
bool TaskEmitAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
CheckGraphOutputConstOrParameter(res->func_graph())) {
return true;
@ -613,6 +620,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
}
bool ExecuteAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
CheckGraphOutputConstOrParameter(res->func_graph())) {
return true;
@ -673,6 +681,7 @@ bool StartFLWorkerAction(const ResourcePtr &) {
}
bool StartPSServerAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
auto &ps = ps::ParameterServer::GetInstance();
ps.Run(func_graph);
@ -680,6 +689,7 @@ bool StartPSServerAction(const ResourcePtr &res) {
}
bool StartServerAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
@ -735,6 +745,8 @@ bool StartPSSchedulerAction(const ResourcePtr &) {
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
// the final solution will be proposed later as a parallel feature.
bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->manager());
auto &node_users = res->manager()->node_users();
auto &users = node_users[value_node];
auto used_by_keep_value_prim =
@ -747,6 +759,7 @@ bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &r
auto prim_node = cnode->input(0);
if (IsValueNode<Primitive>(prim_node)) {
auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
MS_EXCEPTION_IF_NULL(prim);
// value_node is referenced by some parallel primitive
return prim->HasAttr("keep_value_node_input");
}
@ -756,10 +769,11 @@ bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &r
}
bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
}
FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = func_graph->value_nodes();
@ -796,8 +810,8 @@ bool OptActionVmPyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
// Renomalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
@ -817,8 +831,8 @@ bool OptActionGePyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
// Renomalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),

View File

@ -36,7 +36,6 @@ bool AutoMonadAction(const ResourcePtr &res);
bool AbstractSpecializeAction(const ResourcePtr &res);
bool GeOptimizeAction(const ResourcePtr &res);
bool VmOptimizeAction(const ResourcePtr &res);
bool PynativeOptimizeAction(const ResourcePtr &res);
bool PynativeElimOpt(const ResourcePtr &res);
bool TaskEmitAction(const ResourcePtr &res);
bool ExecuteAction(const ResourcePtr &res);

View File

@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
}
ValuePtr converted = nullptr;
bool matched = false;
auto &&converters = GetDataConverters();
auto converters = GetDataConverters();
for (auto &converter : converters) {
if (converter->Matched(obj)) {
converted = converter->ConvertPyObject(obj, use_signature, dtype);

View File

@ -63,7 +63,9 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
// Write variable records the variable name to corresponding node
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
<< node->DebugString();
auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false));
if (!is_new_name) {
// If a cnode variable with same name already existed but not used,
@ -76,9 +78,10 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
auto hidden_node = iter->second.first;
auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
if (!is_used && is_isolated) {
MS_EXCEPTION_IF_NULL(hidden_node);
MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
<< node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
<< "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
<< "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< ", Line: " << trace::GetDebugInfo(hidden_node->debug_info(), "", kSourceLineTipDiscard);
AddIsolatedNode(hidden_node);
}
@ -124,7 +127,8 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
debug_info->set_name(var);
TraceGuard guard(std::make_shared<TracePhi>(debug_info));
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
<< phi_param->ToString() << " for " << var;
func_graph()->add_parameter(phi_param);
phi_nodes_[phi_param] = var;
WriteVariable(var, phi_param);
@ -150,8 +154,9 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
// Resolve class member, two possible: method, member variable
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
py::object namespace_var =
parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
auto ast = parser_.ast();
MS_EXCEPTION_IF_NULL(ast);
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>(attr);
return MakeResolve(name_space, symbol);
@ -168,8 +173,13 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
auto bits_str = value.substr(start);
return MakeResolveClassMember(bits_str);
}
py::tuple namespace_info = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
auto ast = parser_.ast();
MS_EXCEPTION_IF_NULL(ast);
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
const size_t namespace_info_size = 2;
if (namespace_info.size() < namespace_info_size) {
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
}
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
if (namespace_info[0].is_none()) {
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
@ -179,17 +189,15 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
// If the size of namespace_var is less than 3, the default error information is used.
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
}
if (namespace_info.size() < namespace_info_size) {
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
return MakeResolve(name_space, symbol);
}
AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
py::tuple namespace_var = parser_.ast()->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
auto ast = parser_.ast();
MS_EXCEPTION_IF_NULL(ast);
py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
const size_t namespace_var_size = 2;
if (namespace_var.size() < namespace_var_size) {
MS_EXCEPTION(NameError) << "namespace_var is less than 2";
@ -200,28 +208,32 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
}
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , "
<< ((std::string)resolve_symbol->symbol());
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
ValueNodePtr module_node = NewValueNode(name_space);
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
auto node = func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
return node;
}
// Add input for the block's phi parameter
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi);
TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
<< " for var " << var;
auto removable = CollectRemovablePhi(phi);
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
if (removable) {
MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< " var " << var;
return;
}
for (auto &pred : prev_blocks_) {
MS_EXCEPTION_IF_NULL(pred);
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
<< (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
AnfNodePtr arg_node = pred->ReadVariable(var);
CNodePtr jump = pred->jumps_[this];
MS_EXCEPTION_IF_NULL(jump);
@ -235,18 +247,18 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
MS_EXCEPTION_IF_NULL(prev);
AnfNodePtr temp_node = prev->ReadVariable(var);
MS_EXCEPTION_IF_NULL(temp_node);
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var
<< " is " << temp_node->DebugString();
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
<< (phi ? phi->ToString() : "null") << " for var " << var << " is " << temp_node->DebugString();
if (temp_node != phi) {
if (arg_node == nullptr) {
arg_node = temp_node;
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString()
<< " may be replaced by node " << arg_node->DebugString();
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
<< (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
} else if (temp_node == arg_node) {
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node "
<< arg_node->DebugString();
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
<< (phi ? phi->ToString() : "null") << " is same as node " << arg_node->DebugString();
} else {
MS_LOG(DEBUG) << "phi " << phi->ToString()
MS_LOG(DEBUG) << "phi " << (phi ? phi->ToString() : "null")
<< " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
<< ", node2: " << temp_node->DebugString();
return nullptr;
@ -283,8 +295,8 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
if (arg_node != nullptr) {
arg_node->set_debug_info(phi->debug_info());
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
<< arg_node->DebugString();
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
<< " can be replaced with " << arg_node->DebugString();
// Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
WriteVariable(var, arg_node);
removable_phis_[phi] = arg_node;
@ -301,9 +313,10 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
if (phi_iter.second->isa<Parameter>()) {
const auto &param = phi_iter.second->cast<ParameterPtr>();
if (param == phi) {
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
<< " in graph " << arg_node->func_graph()->ToString();
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " var "
<< phi_iter.first->DebugString() << " can be replaced from " << param->DebugString()
<< " with " << arg_node->DebugString() << " in graph "
<< (arg_node->func_graph() ? arg_node->func_graph()->ToString() : "FG(Null)");
prev->removable_phis_[phi_iter.first] = arg_node;
}
}
@ -329,22 +342,25 @@ void FunctionBlock::Mature() {
// Force the condition node to bool using bool operation
CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
MS_EXCEPTION_IF_NULL(cond);
TraceGuard trace_guard(std::make_shared<TraceForceBool>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
return op_apply_node;
}
CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
MS_EXCEPTION_IF_NULL(cond);
TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
return op_apply_node;
}
// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr &node) {
if (func_graph()->get_return() != nullptr) {
MS_EXCEPTION_IF_NULL(target_block);
if (func_graph_->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
<< trace::GetDebugInfo(func_graph_->get_return()->debug_info());
}
std::vector<AnfNodePtr> input_nodes;
input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
@ -352,25 +368,27 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr
input_nodes.emplace_back(node);
}
CNodePtr jump = func_graph()->NewCNodeInOrder(input_nodes);
CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
jumps_[target_block.get()] = jump;
target_block->AddPrevBlock(shared_from_this());
func_graph()->set_output(jump);
func_graph_->set_output(jump);
}
// Perform a conditional jump using switch operation.
// The first CNode select graph with condition, and than execute this graph
void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block, bool unroll_loop) {
if (func_graph()->get_return() != nullptr) {
MS_EXCEPTION_IF_NULL(true_block);
MS_EXCEPTION_IF_NULL(false_block);
if (func_graph_->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
<< trace::GetDebugInfo(func_graph_->get_return()->debug_info());
}
CNodePtr switch_app =
func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
CNodePtr switch_app_new = func_graph()->NewCNodeInOrder({switch_app});
func_graph()->set_output(switch_app_new);
func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
func_graph_->set_output(switch_app_new);
}
// Create cnode for the assign statement like 'self.target = source'.
@ -381,7 +399,7 @@ void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &s
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2) << ", block: " << this
<< "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
<< "/" << func_graph_->ToString()
<< ", Line: " << trace::GetDebugInfo(assign_node->debug_info(), "", kSourceLineTipDiscard);
AddIsolatedNode(assign_node);
}
@ -430,15 +448,15 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
if (isolated_nodes_.empty()) {
return;
}
std::vector<AnfNodePtr> states;
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &node : isolated_nodes_) {
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
if (node->func_graph() == func_graph()) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph_->ToString();
if (node->func_graph() == func_graph_) {
states.emplace_back(node);
} else {
MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(2) << " in " << func_graph_->ToString();
}
}
isolated_nodes_.clear();
@ -452,11 +470,11 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
// do not need to MakeTuple, just use the node.
state = states[1];
} else {
state = func_graph()->NewCNode(states);
state = func_graph_->NewCNode(states);
}
AnfNodePtr old_output = nullptr;
auto return_node = func_graph()->get_return();
auto return_node = func_graph_->get_return();
if (return_node) {
const size_t return_input_size = 2;
if (return_node->inputs().size() < return_input_size) {
@ -466,15 +484,15 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
} else {
old_output = NewValueNode(kNone);
}
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
CNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
MS_EXCEPTION_IF_NULL(state);
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(2);
func_graph()->set_output(depend_node, true);
func_graph_->set_output(depend_node, true);
}
} // namespace parse
} // namespace mindspore

View File

@ -61,6 +61,7 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
}
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
return kFloat32;
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
@ -72,6 +73,7 @@ TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
// If any mixed precision flag add a cast node after the parameter node.
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
MS_EXCEPTION_IF_NULL(func_graph);
TypePtr dst_type;
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
dst_type = kFloat32;
@ -146,12 +148,18 @@ void Parser::CleanParserResource() {
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
// Check whether the functions referred by this function and itself are missing 'return' statement
auto mng = Manage(fn, false);
MS_EXCEPTION_IF_NULL(ast);
for (const auto &func_graph : mng->func_graphs()) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->get_return() != nullptr) {
continue;
}
py::object node = ast->GetAstNode();
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
constexpr auto kMinListSize = 2;
if (ret.size() < kMinListSize) {
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
}
py::str desc =
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
@ -178,49 +186,55 @@ FuncGraphPtr Parser::ParseFuncGraph() {
void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
MS_EXCEPTION_IF_NULL(block);
auto block_fg = block->func_graph();
block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
block->func_graph()->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
block_fg->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
MS_EXCEPTION_IF_NULL(ast_);
py::list args = ast_->GetArgs(fn_node);
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (arg_name == "self") {
continue;
}
}
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block->func_graph());
auto para_node = std::make_shared<Parameter>(block_fg);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(arg_name);
para_node->debug_info()->set_name(arg_name);
block->func_graph()->add_parameter(para_node);
AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
block_fg->add_parameter(para_node);
AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
block->WriteVariable(arg_name, para_after_cast);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
}
}
void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
MS_EXCEPTION_IF_NULL(block);
py::list defaults = ast_->GetArgsDefaultValues(fn_node);
py::list args = ast_->GetArgs(fn_node);
std::vector<std::string> namelist_for_default_value;
std::vector<AnfNodePtr> default_values;
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (arg_name == "self") {
continue;
}
}
namelist_for_default_value.push_back(arg_name);
if (i >= defaults.size()) {
MS_LOG(EXCEPTION) << "Index:" << i << " out of range:" << defaults.size();
}
if (py::isinstance<py::none>(defaults[i])) {
default_values.push_back(NewValueNode(kNull));
} else {
@ -232,7 +246,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
ScopePtr Parser::GetScopeForParseFunction() {
ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
if (!py::isinstance<py::none>(scope_str)) {
auto scope_name = py::cast<std::string>(scope_str);
@ -246,7 +260,7 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
ScopePtr scope = GetScopeForParseFunction();
// The node created in the parsefunction context, will inherit the scope created using scope_guard
ScopeGuard scope_guard(scope);
TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
if (block != nullptr) {
pFunBlock->AddPrevBlock(block);
@ -257,14 +271,12 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
auto current_fg = pFunBlock->func_graph();
auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
MS_LOG(DEBUG) << "The function name is " << function_name;
current_fg->debug_info()->set_name(function_name);
MS_EXCEPTION_IF_NULL(ast_);
py::list deco_list = node.attr("decorator_list");
if (!deco_list.empty()) {
current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
}
bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
if (!ast_->obj().is(ast_->function())) {
set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
@ -289,6 +301,7 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
// Add unused variables as isolate nodes.
for (auto &func_block : func_block_list_) {
MS_EXCEPTION_IF_NULL(func_block);
if (func_block->func_graph()->get_return() != nullptr) {
// Find unused variables.
func_block->FindIsolatedNodes();
@ -313,6 +326,7 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::objec
for (size_t i = 0; i < count; ++i) {
auto node = node_list[i];
block = ParseStatement(block, node);
MS_EXCEPTION_IF_NULL(block);
// Insert appropriate depended items for the function block if it has a return node
if (block->func_graph()->get_return() != nullptr) {
// Skip statements after 'return' (or 'break', 'continue').
@ -352,9 +366,8 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
// Check the node type
AstMainType node_main_type = node_type->main_type();
if (node_main_type != AST_MAIN_TYPE_EXPR) {
MS_LOG(ERROR) << "Node type is error : " << node_main_type;
errcode_ = PARSE_NODE_TYPE_NO_MATCH;
return nullptr;
MS_LOG(EXCEPTION) << "Node type is error : " << node_main_type;
}
// Call the process function
std::string node_name = node_type->node_name();
@ -381,9 +394,15 @@ FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::obje
// False, None, None
//
// Check the expand info result
if (expand_info.empty()) {
MS_LOG(EXCEPTION) << "Empty expand_info.";
}
auto is_expand = py::cast<bool>(expand_info[0]);
if (is_expand) {
// Process the expr statement
if (expand_info.size() < 2) {
MS_LOG(EXCEPTION) << "expand_info size:" << expand_info.size() << " less than 2.";
}
py::object value_object = expand_info[1];
// Make a Expr CNode.
AnfNodePtr call_node = ParseExprNode(block, value_object);
@ -429,6 +448,8 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
const FunctionBlockPtr &false_block) {
MS_EXCEPTION_IF_NULL(true_block);
MS_EXCEPTION_IF_NULL(false_block);
true_block->AddPrevBlock(pre_block);
true_block->Mature();
@ -445,10 +466,9 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
py::object value = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
// Create the cnode
CNodePtr pReturnCNode = block->func_graph()->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
block->func_graph()->set_return(pReturnCNode);
auto block_fg = block->func_graph();
CNodePtr pReturnCNode = block_fg->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
block_fg->set_return(pReturnCNode);
return block;
}
@ -462,17 +482,17 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
// Create left and right ANF node
AnfNodePtr left_node = ParseExprNode(block, left);
if (left_node == nullptr) {
MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
return nullptr;
MS_LOG(EXCEPTION) << "DoBinOp process left node failed: " << errcode();
}
AnfNodePtr right_node = ParseExprNode(block, right);
if (right_node == nullptr) {
MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
return nullptr;
MS_LOG(EXCEPTION) << "DoBinOp process right node failed:" << errcode();
}
// Resolve the op
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(op);
// Create apply node
MS_EXCEPTION_IF_NULL(block->func_graph());
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
}
@ -480,6 +500,7 @@ AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &no
MS_LOG(DEBUG) << "Process ast Name";
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
MS_LOG(DEBUG) << "The Name id is " << name_id;
MS_EXCEPTION_IF_NULL(block);
if (block->IsGlobalVar(name_id)) {
return block->MakeResolveSymbol(name_id);
}
@ -509,9 +530,8 @@ AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
return NewValueNode(data);
} else {
// no else actually
MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj);
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
return nullptr;
MS_LOG(EXCEPTION) << "Unsupported Num type : " << (std::string)py::str(obj);
}
}
@ -558,14 +578,13 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
} else if (py::isinstance<py::none>(obj)) {
MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
return NewValueNode(kNone);
} else {
// no else actually
MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
return nullptr;
}
// no else actually
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
}
AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
std::vector<AnfNodePtr> make_tuple_nodes;
make_tuple_nodes.push_back(make_tuple_op);
@ -575,6 +594,7 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v
}
AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
MS_EXCEPTION_IF_NULL(block);
py::object father_class;
const size_t expect_args_size = 2;
if (args.empty()) {
@ -588,7 +608,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
} else {
MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
}
py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
@ -626,6 +646,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
const std::vector<AnfNodePtr> &packed_arguments) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> unpack_call_nodes;
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
unpack_call_nodes.push_back(unpack_call_op);
@ -640,6 +661,7 @@ AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const A
const std::vector<AnfNodePtr> &packed_arguments,
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
// If there is keyword arguments or starred, using an unpack_call op to unpack the argument
MS_EXCEPTION_IF_NULL(block);
if (need_unpack) {
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
}
@ -654,6 +676,8 @@ AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const A
bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
MS_EXCEPTION_IF_NULL(packed_arguments);
MS_EXCEPTION_IF_NULL(group_arguments);
bool need_unpack = false;
for (size_t i = 0; i < args.size(); i++) {
auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
@ -679,6 +703,8 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
bool need_unpack = false;
py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
if (!keywords.empty()) {
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(packed_arguments);
need_unpack = true;
std::vector<AnfNodePtr> keys;
std::vector<AnfNodePtr> values;
@ -708,14 +734,14 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
// Process call attributes of class type define, eg: x.y()
AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Attribute";
// Process class value,eg: self.xx
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
if (ast_->IsClassMember(node)) {
std::string var_name = "self.";
std::string attr_name = node.attr("attr").cast<std::string>();
(void)var_name.append(attr_name);
auto attr_obj = ast()->obj().attr(attr_name.c_str());
MS_EXCEPTION_IF_NULL(block);
if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
(py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
@ -736,8 +762,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
py::object value_body = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr value_node = ParseExprNode(block, value_body);
if (value_node == nullptr) {
MS_LOG(WARNING) << "Parse attribute failed";
return nullptr;
MS_LOG(EXCEPTION) << "Parse attribute failed";
}
// Process the node attr
@ -761,24 +786,30 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
// For python comparison ,there may be if x>y>5 ,
// Which there is two ops , but we only support one now
py::list ops = python_adapter::GetPyObjAttr(node, "ops");
if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
MS_EXCEPTION(NotSupportError)
<< "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
if (ops.size() != MAX_COMPARISON_OPS_SUPPORTED) {
MS_EXCEPTION(NotSupportError) << "MindSpore only support comparison with operators with one now, ops size ="
<< ops.size();
}
py::object left = python_adapter::GetPyObjAttr(node, "left");
py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
if (comparators.empty()) {
MS_LOG(EXCEPTION) << "Comparators can't be empty.";
}
AnfNodePtr left_node = ParseExprNode(block, left);
AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
}
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
// If there is only one bool op now
MS_EXCEPTION_IF_NULL(block);
if (value_list.empty()) {
MS_LOG(EXCEPTION) << "value list is empty.";
}
if (value_list.size() == 1) {
AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
return first_node;
@ -788,15 +819,15 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
for (size_t i = 1; i < value_list.size(); i++) {
rest.append(value_list[i]);
}
MS_EXCEPTION_IF_NULL(block);
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
auto block_fg = block->func_graph();
{
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
false_block = MakeFunctionBlock(*this);
}
MakeConditionBlocks(block, true_block, false_block);
@ -820,12 +851,12 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
b2->func_graph()->set_output(test_node);
auto cond_node = block->ForceToBoolNode(test_node);
auto switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node,
NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
auto switch_app =
block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
auto switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
auto switch_app_call = block_fg->NewCNodeInOrder(call_graph_nodes);
return switch_app_call;
}
}
@ -836,8 +867,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
py::object op_node = python_adapter::GetPyObjAttr(node, "op");
AstSubType op_type = ast_->GetOpType(op_node);
if (op_type == AST_SUB_TYPE_UNKNOWN) {
MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type";
return nullptr;
MS_LOG(EXCEPTION) << "ProcessBoolOp, got unknown op type";
}
py::list op_values = python_adapter::GetPyObjAttr(node, "values");
return ProcessBoolOpValueList(block, op_values, op_type);
@ -866,20 +896,21 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
// Get lambda args
py::list args = ast_->GetArgs(node);
auto block_fg = func_block->func_graph();
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg = py::cast<std::string>(args[i].attr("arg"));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(func_block->func_graph());
auto para_node = std::make_shared<Parameter>(block_fg);
para_node->debug_info()->set_name(arg);
func_block->func_graph()->add_parameter(para_node);
block_fg->add_parameter(para_node);
func_block->WriteVariable(arg, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
}
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
func_block->func_graph()->set_output(lambda_body_node);
ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
block_fg->set_output(lambda_body_node);
ValueNodePtr const_graph = NewValueNode(block_fg);
return const_graph;
}
@ -888,7 +919,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n
MS_LOG(DEBUG) << "Process ast Tuple";
MS_EXCEPTION_IF_NULL(block);
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.size() == 0) {
if (elts.empty()) {
auto empty_tuple = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
}
@ -909,7 +940,7 @@ AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &no
MS_LOG(DEBUG) << "Process ast List";
MS_EXCEPTION_IF_NULL(block);
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.size() == 0) {
if (elts.empty()) {
auto empty_list = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueList>(empty_list));
}
@ -934,7 +965,6 @@ AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::objec
py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
AnfNodePtr value = ParseExprNode(block, value_node);
AnfNodePtr slice = ParseExprNode(block, slice_node);
return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
}
@ -949,7 +979,6 @@ AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &n
AnfNodePtr start_node = ParseExprNode(block, start);
AnfNodePtr stop_node = ParseExprNode(block, stop);
AnfNodePtr step_node = ParseExprNode(block, step);
return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
}
@ -1060,12 +1089,13 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr true_block = nullptr;
FunctionBlockPtr false_block = nullptr;
auto block_fg = block->func_graph();
{
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
true_block = MakeFunctionBlock(*this);
}
{
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
false_block = MakeFunctionBlock(*this);
}
@ -1073,7 +1103,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr after_block = nullptr;
{
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
after_block = MakeFunctionBlock(*this);
}
@ -1178,7 +1208,8 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
int64_t Parser::GetForTransToWhileLoop() {
// int64 support 63bits positive num mostly.
if (max_for_loop_count_str_.size() > 63 || max_for_loop_count_str_.empty()) {
constexpr auto kMaxNumLength = 10;
if (max_for_loop_count_str_.size() > kMaxNumLength || max_for_loop_count_str_.empty()) {
return MAX_FOR_LOOP_COUNT;
}
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
@ -1634,6 +1665,7 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje
}
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inp = node->cast<ParameterPtr>();
const auto &iter = removable_phis.find(inp);
if (iter == removable_phis.end()) {
@ -1675,6 +1707,7 @@ void Parser::RemoveUnnecessaryPhis() {
std::vector<AnfNodePtr> new_parameters(parameters.size());
auto it = std::copy_if(
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr &param) {
MS_EXCEPTION_IF_NULL(param);
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
});
@ -1762,8 +1795,7 @@ AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
const size_t list_value_size = 2;
if (list_value.size() < list_value_size) {
MS_LOG(ERROR) << "The node of python method must has 2 values.";
return nullptr;
MS_LOG(EXCEPTION) << "The node of python method must has 2 values.";
}
auto node_name = py::cast<std::string>(list_value[0]);
auto type = AstMainType(py::cast<int32_t>(list_value[1]));
@ -1828,6 +1860,7 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph,
std::size_t index = 0;
std::vector<AnfNodePtr> old_cnodes;
old_cnodes.emplace_back(param_node);
MS_EXCEPTION_IF_NULL(func_graph);
auto res = func_graph->NewCNodeInOrder({});
std::vector<CNodePtr> new_cnodes;
new_cnodes.emplace_back(res);
@ -1835,7 +1868,6 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph,
auto current = old_cnodes[index];
auto current_new_cnode = new_cnodes[index];
index++;
MS_EXCEPTION_IF_NULL(current);
if (current->isa<CNode>()) {
auto &inputs = current->cast<CNodePtr>()->inputs();
for (auto it = inputs.begin(); it != inputs.end(); it++) {

View File

@ -63,6 +63,8 @@ using mindspore::abstract::AnalysisContextPtr;
using mindspore::validator::Validate;
namespace {
void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(res);
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
@ -76,15 +78,16 @@ void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const Re
} // namespace
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
DoRenormalize(changed, func_graph, res);
return true;
}
bool TransformTopGraphPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Transform top graph error.";
}
@ -103,15 +106,16 @@ bool TransformTopGraphPass(const ResourcePtr &res) {
}
bool CleanAfterOptAPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
DoRenormalize(changed, func_graph, res);
return true;
}
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
irpass.pynative_eliminate_,
@ -137,6 +141,7 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
}
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
@ -513,6 +518,7 @@ void ReclaimOptimizer() {
}
bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
MS_EXCEPTION_IF_NULL(res);
if (res->func_graph() == nullptr) {
MS_LOG(ERROR) << "Opt passes int64_t error";
return false;
@ -551,6 +557,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
}
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_ps_mode()) {
return true;
@ -571,6 +578,7 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res) {
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
}
@ -588,6 +596,7 @@ bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
}
bool CconvPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
FuncGraphPtr new_fg = LiftingClone(func_graph);
@ -602,6 +611,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> new_paras;
for (const auto &param : func_graph->parameters()) {
auto param_node = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
if (param_node->has_default()) {
new_paras.push_back(param_node);
continue;
@ -618,6 +628,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
}
bool ValidatePass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
Validate(func_graph);

View File

@ -143,6 +143,7 @@ std::string GetCompileExceptionInfo() {
}
void SetGpuLoopSink(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto func_graph = resource->func_graph();
if (func_graph != nullptr && func_graph->manager() != nullptr) {
auto manager = func_graph->manager();
@ -452,13 +453,17 @@ ExecutorPy::~ExecutorPy() {
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
MS_EXCEPTION_IF_NULL(root_node);
MS_EXCEPTION_IF_NULL(fake_quant_table);
std::string weight_name;
auto x = root_node->input(1);
MS_EXCEPTION_IF_NULL(x);
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
} else {
weight_name = weight_node->cast<ParameterPtr>()->name();
auto para = weight_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para);
weight_name = para->name();
}
// find the fakequant from input
int64_t count = 0;
@ -501,9 +506,12 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
} else {
fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
auto param = fakequant_min_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
fakequant_min_node_name = param->name();
}
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(quant_op_value);
if (!quant_op_value->isa<PrimitivePy>()) {
return;
}
@ -602,6 +610,7 @@ bool IsPhaseTrain(const std::string &phase_s) {
}
std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
MS_EXCEPTION_IF_NULL(resource);
bool is_air = IsPhaseExportAir(phase_s);
std::string backend = MsContext::GetInstance()->backend_policy();

View File

@ -69,6 +69,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";

View File

@ -54,6 +54,7 @@ void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
}
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
if (analysis_finish_flg_) {
return;
}
@ -106,7 +107,9 @@ void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &nod
const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
MS_EXCEPTION_IF_NULL(arg_info);
MS_EXCEPTION_IF_NULL(user_node);
auto cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const size_t tuple_get_item_size = 3;
const size_t index = 2;
if (cnode->size() != tuple_get_item_size) {
@ -140,6 +143,7 @@ void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &nod
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info) const {
MS_EXCEPTION_IF_NULL(arg_info);
MS_EXCEPTION_IF_NULL(param);
auto abs = param->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
@ -201,6 +205,7 @@ FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bpro
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
const ValuePtr &out, const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(bprop_fg);
abstract::AbstractBasePtrList abs_list;
ArgsToAbs(prim, op_args, &abs_list);
@ -233,6 +238,7 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
MS_EXCEPTION_IF_NULL(bprop_fg);
abstract::AbstractBasePtrList abs_list;
ArgsToAbs(prim, op_args, &abs_list);
if (!hook_flg) {
@ -272,6 +278,7 @@ PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPt
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
const abstract::AbstractBasePtrList &abs_list_input) {
MS_EXCEPTION_IF_NULL(bprop_fg);
auto &params = bprop_fg->parameters();
if (abs_list_input.size() != params.size()) {
MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size();
@ -306,6 +313,9 @@ ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
const abstract::AbstractBasePtrList &abs_list,
PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
PrimBpropOptGraphInfoPtr *level_1_graph_info) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(level_1_graph_info);
MS_EXCEPTION_IF_NULL(level_2_graph_info);
auto attrs_ = prim->attrs();
for (auto &item : attrs_) {
MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
@ -327,6 +337,8 @@ ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
abstract::AbstractBasePtrList *abs_list) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(abs_list);
auto const_input_index = prim->get_const_input_indexes();
bool have_const_input = !const_input_index.empty();
bool is_const_prim = prim->is_const_prim();
@ -345,6 +357,7 @@ void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
const abstract::AbstractBasePtrList &abs_list) {
MS_EXCEPTION_IF_NULL(out);
if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments.";
}

View File

@ -29,6 +29,8 @@ namespace mindspore {
namespace pipeline {
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
HashValue *const hash_value) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(hash_cache);
const auto &to_check_value = GetValueNode(node);
MS_EXCEPTION_IF_NULL(to_check_value);

View File

@ -72,8 +72,10 @@ void AnalysisSchedule::SetNextRunnableImpl() {
return;
}
// Check if enter endless loop
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(),
[](const auto &item) { return item->HasResult(); });
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) {
MS_EXCEPTION_IF_NULL(item);
return item->HasResult();
});
if (it == asyncAbstractList_.end()) {
// Enter endless loop if there is not ready result.
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
@ -184,6 +186,7 @@ void AnalysisResultCacheMgr::Todo() {
std::lock_guard<std::mutex> lock(todo_lock_);
while (!todo_.empty()) {
AnfNodeConfigPtr conf = todo_.front();
MS_EXCEPTION_IF_NULL(conf);
todo_.pop_front();
if (GetValue(conf) == nullptr) {
MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
@ -193,11 +196,14 @@ void AnalysisResultCacheMgr::Todo() {
MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
continue;
}
if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {
auto switch_value = TryGetSwitchValue(conf);
auto abstract = GetValue(conf)->abstract();
MS_EXCEPTION_IF_NULL(switch_value);
MS_EXCEPTION_IF_NULL(abstract);
if (!(*abstract == *switch_value)) {
MS_LOG(WARNING) << " Switch Value is not eq. "
<< " switch cache: " << TryGetSwitchValue(conf)->ToString()
<< " globle cache: " << GetValue(conf)->abstract()->ToString()
<< "\tconf: " << conf->node()->ToString();
<< " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
<< "\t\tConf: " << conf->ToString();
}
}
}

View File

@ -41,6 +41,7 @@ using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
// Add or get a monad parameter.
AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(func_graph);
size_t params_size = func_graph->parameters().size();
size_t io_monad_location = params_size;
// Search for existed parameters, return it if found.
@ -109,6 +110,7 @@ bool HasAbstractRef(const AnfNodePtr &node) {
// Gets ref inputs and its indexes from a cnode.
RefInputs GetRefInputs(const CNodePtr &cnode) {
RefInputs ref_inputs;
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->size(); ++i) {
auto &input = cnode->inputs().at(i);
if (HasAbstractRef(input)) {
@ -231,6 +233,7 @@ class SccFinder {
// Search SCCs from the given graph.
const State &Search(FuncGraphPtr graph) {
// Create graph state, set it as visited.
MS_EXCEPTION_IF_NULL(graph);
auto [inserted, ok] = visited_.emplace(graph, State(index_++));
if (!ok) {
MS_LOG(EXCEPTION) << "Already visited: " << graph->ToString();
@ -336,6 +339,7 @@ class SideEffectFinder {
}
static void UpdateOrderList(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
OrderedSet<CNodePtr> new_order_list;
const auto &order_list = func_graph->order_list();
for (auto &cnode : order_list) {
@ -368,11 +372,13 @@ class SideEffectFinder {
// Gets branch graph from a switch cnode at given input index.
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
MS_EXCEPTION_IF_NULL(cnode);
return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
}
// Gets branch graphs from a switch cnode.
std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t switch_cnode_size = 4;
constexpr size_t true_index = 2;
constexpr size_t false_index = 3;
@ -457,6 +463,7 @@ class SideEffectFinder {
// Gets branch graphs from a switch_layer cnode.
std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t func_tuple_index = 2;
if (cnode->size() <= func_tuple_index) {
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
@ -485,11 +492,12 @@ class SideEffectFinder {
if (func_graph != nullptr) {
return GetGraphsFromTuple(func_graph->output());
}
MS_LOG(EXCEPTION) << "Invalid input for switch_layer: " << func_tuple->DebugString(2);
MS_LOG(EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
}
// Get graphs from a tuple of funcs make node for switch_layer.
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
MS_EXCEPTION_IF_NULL(make_tuple);
auto &inputs = make_tuple->inputs();
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
@ -531,6 +539,7 @@ class SideEffectFinder {
}
EffectInfo TraceTupleEffectInfo(const AnfNodePtr &tuple_node, std::stack<int64_t> *tuple_indexes) {
MS_EXCEPTION_IF_NULL(tuple_indexes);
auto para = dyn_cast<Parameter>(tuple_node);
if (para != nullptr) {
return TraceTupleParaEffectInfo(para, *tuple_indexes);
@ -540,7 +549,7 @@ class SideEffectFinder {
return TraceTupleCNodeEffectInfo(tuple_cnode, tuple_indexes);
}
// Should not reach here.
MS_LOG(EXCEPTION) << "Side effects untraceable: " << tuple_node->DebugString();
MS_LOG(EXCEPTION) << "Side effects untraceable: tuple_cnode is nullptr.";
}
EffectInfo TraceTupleParaEffectInfo(const ParameterPtr &para, const std::stack<int64_t> &tuple_indexes) {
@ -556,6 +565,7 @@ class SideEffectFinder {
EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
MS_EXCEPTION_IF_NULL(tuple_indexes);
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
@ -640,6 +650,7 @@ class SideEffectFinder {
}
// Set merged effect info to both branches.
for (auto &branch : branches) {
MS_EXCEPTION_IF_NULL(branch);
branch->SetEffectInfo(info);
// Update caller if it is existed.
UpdateBranchCaller(branch);
@ -650,6 +661,7 @@ class SideEffectFinder {
EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
EffectInfo info = {EffectInfo::kDetected, false, false, false};
for (auto &branch : branches) {
MS_EXCEPTION_IF_NULL(branch);
EffectInfo branch_info = GetEffectInfo(branch);
info.Merge(branch_info);
}
@ -658,6 +670,7 @@ class SideEffectFinder {
// Trace a cnode for effect info.
EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
// Special handling for Switch primitive.
@ -740,7 +753,7 @@ class SideEffectFinder {
}
}
// Something is wrong if we reached here.
MS_LOG(WARNING) << "EffectInfo untraceable: " << node->DebugString(2);
MS_LOG(WARNING) << "EffectInfo untraceable: node is a nullptr.";
return {EffectInfo::kDetected, false, false, false};
}
@ -767,6 +780,7 @@ class SideEffectFinder {
}
void ForEachRealArguments(const ParameterPtr &para, const std::function<void(const AnfNodePtr &)> &handler) {
MS_EXCEPTION_IF_NULL(para);
auto func_graph = para->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Find index of the parameter, starts from 0.
@ -785,6 +799,7 @@ class SideEffectFinder {
}
// Caller cnode.
auto cnode = dyn_cast<CNode>(user.first->first);
MS_EXCEPTION_IF_NULL(cnode);
if (cnode && input_index < cnode->size()) {
handler(cnode->input(input_index));
}
@ -793,6 +808,7 @@ class SideEffectFinder {
// For call node, returns effect info of the callee graph.
EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t min_call_node_size = 2;
if (cnode->size() < min_call_node_size) {
MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
@ -886,20 +902,23 @@ class SideEffectFinder {
const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
auto found = scc_map_.find(func_graph);
if (found == scc_map_.end()) {
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString();
MS_LOG(EXCEPTION) << "SCC not found for " << (func_graph ? func_graph->ToString() : "FG(null)");
}
return found->second;
}
// Set effect info for all member graphs in the SCC.
void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
MS_EXCEPTION_IF_NULL(scc);
for (auto &g : *scc) {
MS_EXCEPTION_IF_NULL(g);
g->SetEffectInfo(info);
}
}
// Gets EffectInfo for func graph.
EffectInfo GetEffectInfo(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
const auto &effect_info = func_graph->GetEffectInfo();
if (effect_info.state != EffectInfo::kUnknown) {
// Effect info already set, return it.
@ -907,6 +926,7 @@ class SideEffectFinder {
}
// Get SCC that this graph belongs to.
auto &scc = GetScc(func_graph);
MS_EXCEPTION_IF_NULL(scc);
// To prevent SCC members be visited again, we set effect info
// to 'kDetecting' state before start to check cnodes.
EffectInfo info{EffectInfo::kDetecting, false, false, false};
@ -914,6 +934,7 @@ class SideEffectFinder {
// Check side effects for all cnodes in the SCC.
std::vector<CNodePtr> undetected;
for (auto &g : *scc) {
MS_EXCEPTION_IF_NULL(g);
for (auto &cnode : g->order_list()) {
auto cnode_effect = GetEffectInfo(cnode);
if (cnode_effect.state != EffectInfo::kDetected) {
@ -935,6 +956,7 @@ class SideEffectFinder {
SetSccEffectInfo(scc, info);
// Check undetected cnodes again after side effect of the SCC is detected.
for (auto &cnode : undetected) {
MS_EXCEPTION_IF_NULL(cnode);
auto cnode_effect = GetEffectInfo(cnode);
// Side effect should be detected now.
if (cnode_effect.state != EffectInfo::kDetected) {
@ -951,6 +973,8 @@ class SideEffectFinder {
}
void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphPtr &branch) {
MS_EXCEPTION_IF_NULL(branch);
MS_EXCEPTION_IF_NULL(switch_node);
auto manager = branch->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
@ -971,6 +995,7 @@ class SideEffectFinder {
}
void UpdateBranchCaller(const FuncGraphPtr &branch) {
MS_EXCEPTION_IF_NULL(branch);
auto iter = branch_caller_map.find(branch);
if (iter == branch_caller_map.end()) {
return;
@ -992,6 +1017,8 @@ class SideEffectFinder {
}
void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(monad);
auto monad_abs = monad->ToAbstract();
for (size_t i = 1; i < cnode->size(); ++i) {
auto abs = cnode->inputs().at(i)->abstract();
@ -1077,6 +1104,7 @@ class AutoMonadConverter {
// Gets effect info for a cnode.
const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode);
auto &effect_info = cnode->GetEffectInfo();
if (effect_info.state != EffectInfo::kDetected) {
// Effect info should have been set by SideEffectFinder.
@ -1135,6 +1163,7 @@ class AutoMonadConverter {
}
void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
MS_EXCEPTION_IF_NULL(cnode);
if (info.memory || info.load) {
(void)GetUniverse();
bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
@ -1187,6 +1216,7 @@ class AutoMonadConverter {
}
void HandleLoad(const CNodePtr &cnode, bool update_state) {
MS_EXCEPTION_IF_NULL(cnode);
auto value = GetValueNode(cnode->input(0));
if (value && value->isa<Primitive>()) {
// For primitive calls that use Ref as input, insert Loads before them.
@ -1273,6 +1303,7 @@ class AutoMonadConverter {
// Add or replace monad input.
void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t max_monad_inputs = 2;
auto monad_abs = monad->abstract();
auto &inputs = cnode->inputs();
@ -1332,6 +1363,7 @@ class AutoMonadConverter {
}
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
MS_EXCEPTION_IF_NULL(attach);
// Not attach UpdateState if set kAttrIgnoreSideEffect.
auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
@ -1453,6 +1485,7 @@ bool AutoMonad(const FuncGraphPtr &func_graph) {
bool ReAutoMonad(const FuncGraphPtr &func_graph) {
// AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
// Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
MS_EXCEPTION_IF_NULL(func_graph);
bool need_auto_monad = false;
std::vector<FuncGraphPtr> auto_monaded_fg;
func_graph->EraseUnusedNodeInOrder();

View File

@ -38,7 +38,8 @@ string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList
ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
}
for (size_t i = 0; i < arg_spec_list.size(); i++) {
ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString();
ss << evaluator->ToString() << " input[" << i
<< "] abstract value: " << (arg_spec_list[i] ? arg_spec_list[i]->ToString() : "null abstract.");
}
return ss.str();
}
@ -60,6 +61,9 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame) {
MS_EXCEPTION_IF_NULL(current_stack_frame);
MS_EXCEPTION_IF_NULL(new_stack_frame);
MS_EXCEPTION_IF_NULL(engine);
// Enter new func graph.
auto &current_node = current_stack_frame->CurrentNode();
auto current_context = current_stack_frame->current_context();
@ -83,8 +87,8 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
<< "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
}
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
const StackFramePtr &current_stack_frame) {
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame) {
MS_EXCEPTION_IF_NULL(current_stack_frame);
// Leave current func graph.
auto current_context = current_stack_frame->current_context();
trace::TraceGraphEvalLeave(current_context);
@ -149,8 +153,11 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
const AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(engine);
const AnfNodePtr &func_node = fg->get_return();
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return EXCLUDE;
}
@ -162,7 +169,9 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
MS_EXCEPTION_IF_NULL(node_eval_result);
res_base = node_eval_result->abstract();
MS_EXCEPTION_IF_NULL(res_base);
MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
}
MS_EXCEPTION_IF_NULL(res_base);
@ -254,6 +263,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
MS_EXCEPTION_IF_NULL(broaded_args);
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
@ -467,6 +477,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
// Call the original evaluator, get the result: y = f(x)
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
MS_EXCEPTION_IF_NULL(result);
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList bparams;

View File

@ -228,7 +228,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
const AnalysisContextPtr &context);
static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame);
static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);
static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame);
};
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {

View File

@ -199,6 +199,7 @@ class OrderEnforcer {
}
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
// Find refs from the cnode inputs.
auto &inputs = cnode->inputs();
const size_t last_index = inputs.size() - 1;
@ -232,6 +233,7 @@ class OrderEnforcer {
}
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
MS_EXCEPTION_IF_NULL(update_state);
const size_t attach_index = 2;
const size_t input_size = update_state->inputs().size();
for (size_t index = attach_index; index < input_size; index++) {
@ -368,6 +370,7 @@ class OrderEnforcer {
// Enforce order of execution for Load users node.
void OrderEnforce(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
OrderEnforcer enforcer(func_graph);
enforcer.Run();
auto fg_used_total = func_graph->func_graphs_used_total();

View File

@ -52,10 +52,17 @@ std::unordered_set<std::string> prims_to_skip_undetermined_infer{
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(out_conf);
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
[](const ConfigPtr &ref) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(ref);
MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
return ref->ObtainEvalResult()->abstract();
});
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
MS_EXCEPTION_IF_NULL(do_signature);
auto &func = do_signature->function();
if (func->isa<Primitive>()) {
auto sig_prim = func->cast<PrimitivePtr>();
@ -67,17 +74,16 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
}
}
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
auto out_node = dyn_cast<CNode>(out_conf->node());
MS_EXCEPTION_IF_NULL(out_node);
const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString()
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size();
MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
<< args_conf_list.size() << ", inputs size " << out_node_inputs.size();
}
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
@ -90,11 +96,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodePtr new_node = nullptr;
if (bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
} else {
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
}
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
@ -137,12 +141,17 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
MS_EXCEPTION_IF_NULL(unpack_graph);
auto out_node = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_node);
const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
@ -152,8 +161,15 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
[](const ConfigPtr &ref) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(ref);
MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
return ref->ObtainEvalResult()->abstract();
});
// get the forward graph
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
if (fn == nullptr) {
@ -165,7 +181,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
MS_EXCEPTION_IF_NULL(forward_graph);
AbstractBasePtrList graph_specialize_args =
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
AbstractBasePtrList graph_specialize_args_without_sens;
if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
@ -188,6 +203,8 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(node_type);
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr target_node = source_node;
if (node_type->isa<AbstractTensor>()) {
auto x = node_type->cast<AbstractTensorPtr>();
@ -239,12 +256,14 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
AbstractBasePtrList args_spec_list;
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
auto out_node = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_node);
const auto &out_node_inputs = out_node->inputs();
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
MS_LOG(EXCEPTION) << "MixedPrecisionCast"
@ -258,8 +277,12 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
FuncGraphPtr func_graph = out_conf->node()->func_graph();
FuncGraphPtr func_graph = out_node->func_graph();
constexpr size_t source_node_index = 2;
if (out_node_inputs.size() <= source_node_index) {
MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
}
AnfNodePtr new_node =
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
@ -282,6 +305,7 @@ py::object BuildValue(const ValuePtr &value_ptr) {
py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
MS_EXCEPTION_IF_NULL(arg_tuple);
size_t len = arg_tuple->size();
py::tuple shape_tuple(len);
py::tuple dtype_tuple(len);
@ -317,6 +341,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
auto dic = py::dict();
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
@ -337,6 +362,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
auto arg_list = dyn_cast<AbstractList>(abs_base);
MS_EXCEPTION_IF_NULL(arg_list);
size_t len = arg_list->size();
py::list shape_list(len);
py::list dtype_list(len);
@ -361,6 +387,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
auto dic = py::dict();
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
if (arg_list->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
@ -377,6 +404,9 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
MS_EXCEPTION_IF_NULL(dic);
MS_EXCEPTION_IF_NULL(arg_tensor);
MS_EXCEPTION_IF_NULL(arg_tensor->shape());
(*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
@ -399,12 +429,15 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *di
}
void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
MS_EXCEPTION_IF_NULL(dic);
MS_EXCEPTION_IF_NULL(abs_base);
(*dic)[ATTR_SHAPE] = py::none();
(*dic)[ATTR_DTYPE] = abs_base->BuildType();
(*dic)[ATTR_VALUE] = py::none();
if (abs_base->isa<PartialAbstractClosure>()) {
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
if (!args.empty()) {
MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
if (value != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
@ -466,6 +499,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_VALUE] = py::none();
} else {
auto value = abs_base->BuildValue();
MS_EXCEPTION_IF_NULL(value);
if ((*value == *kAnyValue)) {
auto value_desc = abs_base->value_desc();
MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
@ -490,6 +524,8 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
}
void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(res_spec);
const string kOutputNum = "output_num";
if (prim->IsCustomPrim()) {
// Raise error if output_num is not match the infer result.
@ -609,8 +645,10 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value =
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
std::all_of(args.begin(), args.end(),
[](const AbstractBasePtr &abs) -> bool { return (abs->BuildValue() != nullptr); });
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
return (abs->BuildValue() != nullptr);
});
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();
@ -684,8 +722,16 @@ EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Ab
TypePtrList selections;
MS_EXCEPTION_IF_NULL(item.second);
(void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
[&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
[&args](size_t arg_idx) -> TypePtr {
if (arg_idx >= args.size()) {
MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
}
MS_EXCEPTION_IF_NULL(args[arg_idx]);
return args[arg_idx]->GetTypeTrack();
});
TypePtr res = CheckTypeList(item.first, selections);
MS_EXCEPTION_IF_NULL(return_value_type_);
MS_EXCEPTION_IF_NULL(item.first);
if (*return_value_type_ == *(item.first)) {
ret_value_type = res;
}
@ -806,6 +852,7 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
input.push_back(conf->node());
MS_EXCEPTION_IF_NULL(old_conf);
FuncGraphPtr func_graph = old_conf->node()->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
CNodePtr new_cnode = func_graph->NewCNode(input);
if (require_type == REQUIRE_TYPE::ATTR) {
new_cnode = func_graph->NewCNode({new_cnode});
@ -829,11 +876,13 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
auto data_v = args_spec_list[0]->BuildValue();
MS_EXCEPTION_IF_NULL(data_v);
if (!data_v->isa<parse::NameSpace>()) {
MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
}
auto item_v = args_spec_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(item_v);
if (item_v->isa<StringImm>()) {
item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
}
@ -845,9 +894,10 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
// item_name to func addr from obj_map
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
MS_EXCEPTION_IF_NULL(out_conf);
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
if (new_node == nullptr) {
MS_LOG(EXCEPTION) << "Resolve node failed";
@ -857,6 +907,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
@ -877,7 +928,7 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
std::string item_name = item_v->cast<StringImmPtr>()->value();
MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
MS_LOG(DEBUG) << "Resolve item: " << item_name;
MS_EXCEPTION_IF_NULL(cls);
AbstractBasePtr attr = cls->GetAttribute(item_name);
if (attr != nullptr) {
return std::make_shared<EvalResult>(attr, nullptr);
@ -885,6 +936,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
ValuePtr method = cls->GetMethod(item_name);
if (method->isa<AnyValue>()) {
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
}
@ -920,6 +973,7 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
if (require.is<std::string>()) {
// composite registered in standard_method_map go to this branch
converted_v = prim::GetPythonOps(require.cast<std::string>());
MS_EXCEPTION_IF_NULL(converted_v);
if (!converted_v->isa<Primitive>()) {
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
}
@ -945,6 +999,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
scope = out_conf->node()->scope();
}
ScopeGuard scope_guard(scope);
MS_EXCEPTION_IF_NULL(item_value);
if (item_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
@ -973,7 +1028,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
}
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(node_conf);
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
x = SensitivityTransform(x);
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
@ -983,12 +1038,12 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
};
static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
MS_EXCEPTION_IF_NULL(manager);
auto root_g_set = manager->roots();
if (root_g_set.size() != 1) {
return nullptr;
}
const FuncGraphPtr &root_g = root_g_set.back();
for (auto &param_node : root_g->parameters()) {
auto param = param_node->cast<ParameterPtr>();
if (param && name == param->name()) {
@ -1014,8 +1069,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
return nullptr;
}
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
MS_EXCEPTION_IF_NULL(abs);
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
@ -1126,6 +1182,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
TypePtr type = args_spec_list[0]->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kMetaTypeTypeType) {
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
<< type->ToString();
@ -1179,6 +1236,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// Exclude class type by minus 1;
std::size_t params_size = args_spec_list.size() - 1;
auto params = py::tuple(params_size);
if (params_size > params.size()) {
MS_LOG(EXCEPTION) << "Unexpected params_size:" << params_size << ",params.size():" << params.size();
}
if (params_size > 0) {
for (size_t i = 0; i < params_size; i++) {
// Only support the Scalar parameters type. Bypass class type by offset with 1.
@ -1209,9 +1269,11 @@ class PartialEvaluator : public Evaluator {
MS_EXCEPTION_IF_NULL(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
MS_EXCEPTION_IF_NULL(arg0_value);
AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) {
MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString();
@ -1259,7 +1321,8 @@ class PartialEvaluator : public Evaluator {
}
EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
const AnfNodeConfigPtr &out_conf = nullptr) const {
const AnfNodeConfigPtr &out_conf) const {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto cnode = out_conf->node()->cast<CNodePtr>();
@ -1273,7 +1336,7 @@ class PartialEvaluator : public Evaluator {
ScopePtr scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
MS_EXCEPTION_IF_NULL(func_graph);
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
return engine->ForwardConfig(out_conf, fn_conf);
@ -1456,7 +1519,7 @@ bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
if (model->IsGeneric()) {
return true;
}
MS_EXCEPTION_IF_NULL(model_class);
if (x_class->tag() == model_class->tag()) {
auto m_attributes = model_class->GetAttributes();
auto x_attributes = x_class->attributes();

View File

@ -32,13 +32,16 @@ namespace mindspore {
namespace abstract {
namespace {
inline AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
if (conf->node()->intermediate_abstract()) {
return conf->node()->intermediate_abstract();
}
MS_EXCEPTION_IF_NULL(conf->ObtainEvalResult());
return conf->ObtainEvalResult()->abstract();
}
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
MS_EXCEPTION_IF_NULL(abs_base);
AnfNodePtr value_node = NewValueNode(v);
value_node->set_abstract(abs_base);
MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
@ -53,6 +56,7 @@ bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
}
bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
MS_EXCEPTION_IF_NULL(abs_base);
if (abs_base->isa<AbstractTensor>()) {
return true;
} else if (abs_base->isa<AbstractSequeue>()) {
@ -69,7 +73,8 @@ bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(context);
MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
MS_LOG(DEBUG) << "Specialize topmost function graph: "
<< (context->func_graph() ? context->func_graph()->ToString() : "FG(Null)");
if (top_context_ == nullptr) {
top_context_ = context;
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
@ -82,6 +87,7 @@ FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, con
MS_EXCEPTION_IF_NULL(context);
auto iter = specializations_.find(context->SpecializeKey());
if (iter != specializations_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
return iter->second->specialized_func_graph();
}
@ -132,11 +138,11 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
// If had replicated, just return that.
MS_EXCEPTION_IF_NULL(specializer->repl_node_);
auto iter = specializer->repl_node_->find(node);
if (iter != specializer->repl_node_->end()) {
return iter->second;
}
auto new_node = specializer->cloner_->CloneDisconnected(node);
if (node->isa<CNode>()) {
if (!new_node->isa<CNode>()) {
@ -157,6 +163,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
}
void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(node);
auto c_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
auto inputs = c_node->inputs();
@ -170,6 +177,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
MS_EXCEPTION_IF_NULL(c_inp);
auto c_new_inp = new_inp->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_new_inp);
MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString();
c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
}
@ -208,8 +216,9 @@ std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(co
MS_EXCEPTION_IF_NULL(specializer_->top_context());
if (specializer_->top_context()->func_graph() == fg) { // `fg` is top func graph.
specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
MS_LOG(INFO) << "Used top func graph specializer as parent for "
<< (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
MS_EXCEPTION_IF_NULL(specializer);
break;
}
@ -219,21 +228,28 @@ std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(co
if (specializer == nullptr) {
MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
<< func_graph_->ToString() << " has no parent context? At least not " << fg->ToString();
<< (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< " has no parent context? At least not " << fg->ToString();
}
}
return specializer;
}
void FuncGraphSpecializer::Run() {
MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
<< ", func graph: " << func_graph_->get_return()->DebugString();
MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< ", cloned func graph name: "
<< (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
<< (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
: "FG(null)");
FirstPass();
SecondPass();
MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
<< ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
<< ", cloned func graph name: "
<< (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
<< (specialized_func_graph_ ? specialized_func_graph_->get_return()
? specialized_func_graph_->get_return()->DebugString()
: "return null"
: "FG(null)");
}
void FuncGraphSpecializer::FirstPass() {
@ -291,8 +307,8 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(new_node);
if (new_node->func_graph() != specialized_func_graph_) {
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
<< ", new_node: " << new_node->DebugString()
<< ", new_node->func_graph(): " << new_node->func_graph()->ToString()
<< ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
<< (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
return;
}
@ -310,6 +326,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
auto attrs = conf->ObtainEvalResult()->attribute();
auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_new);
auto new_inputs = c_new->inputs();
auto old_inputs = c_old->inputs();
for (size_t i = 0; i < old_inputs.size(); ++i) {
@ -321,7 +338,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(iconf);
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_abstract(ival);
MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
} else {
@ -343,7 +359,7 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
auto conf_iter = engine_->anfnode_config_map().find(conf);
AnfNodeConfigPtr new_conf = conf;
while (conf_iter != engine_->anfnode_config_map().end()) {
MS_LOG(DEBUG) << "Origin conf: node(" << new_conf->node()->DebugString() << ")";
MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
new_conf = conf_iter->second;
MS_EXCEPTION_IF_NULL(new_conf);
const auto &forward_node = new_conf->node();
@ -366,15 +382,17 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
// CloneOrderlist, and it will be replaced inside ReplicateDisconnectedNode.
// For 2.1 the following code will do the job, replace replicated origin cnode with the replicated
// forward one in the replicated func_graph.
MS_EXCEPTION_IF_NULL(conf_iter->first);
const auto &origin_node = conf_iter->first->node();
const auto &replicated_origin_node = GetReplicatedNode(origin_node);
if (replicated_origin_node != origin_node) {
MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
<< ", with replicated forwarded node: " << replicated_forward_node->DebugString();
MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
} else {
MS_LOG(EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
<< origin_node->DebugString();
<< (origin_node ? origin_node->DebugString() : "Node(Null)");
}
}
conf_iter = engine_->anfnode_config_map().find(new_conf);
@ -406,6 +424,7 @@ inline bool CanSpecializeNode(const AnfNodePtr &node) {
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractBasePtrList &argvals) {
MS_EXCEPTION_IF_NULL(abs);
MS_EXCEPTION_IF_NULL(node);
AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
MS_EXCEPTION_IF_NULL(real_a);
@ -429,9 +448,11 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
}
// Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
MS_EXCEPTION_IF_NULL(func);
if (func->isa<MetaFuncGraphAbstractClosure>()) {
auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals[argvals.size() - 1]->isa<AbstractUMonad>()) {
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr &&
argvals.back()->isa<AbstractUMonad>()) {
specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
}
}
@ -482,6 +503,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
<< ", graph: " << context->func_graph()->get_return()->DebugString();
MS_EXCEPTION_IF_NULL(context->func_graph());
if (context->func_graph()->stub()) {
MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
<< ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
@ -489,6 +511,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
return node;
}
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
MS_EXCEPTION_IF_NULL(v);
v->set_flag(kFuncGraphFlagUndetermined, false);
return BuildValueNode(v, abs);
}
@ -505,8 +528,13 @@ AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &en
}
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(new_node);
auto new_inputs = new_node->inputs();
if (new_inputs.empty()) {
MS_LOG(EXCEPTION) << "inputs can't be empty.";
}
AnfNodePtr func = new_inputs[0];
MS_EXCEPTION_IF_NULL(new_inputs[0]);
AbstractBasePtr fnval = new_inputs[0]->abstract();
AbstractBasePtrList args;
@ -549,6 +577,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
partial_node_list.push_back(old_node);
}
}
MS_EXCEPTION_IF_NULL(new_node->func_graph());
wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
wrapped_node->set_abstract(partial_closure);
}
@ -556,6 +585,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
}
const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
MS_EXCEPTION_IF_NULL(eval);
auto cache_iter = evalcaches_.find(eval);
if (cache_iter == evalcaches_.end()) {
evalcaches_[eval] = eval->evaluator_cache_mgr();
@ -571,7 +601,11 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
EvalResultPtr ret = nullptr;
AbstractBasePtrList broaded_argvals;
std::vector<AbstractBasePtrList> args_vector;
auto &origin_eval_cache = evalcaches_[eval]->GetCache();
auto eval_cache_iter = evalcaches_.find(eval);
if (eval_cache_iter == evalcaches_.end()) {
MS_LOG(EXCEPTION) << "Evaluator:" << eval->ToString() << " not exist in cache.";
}
auto &origin_eval_cache = eval_cache_iter->second->GetCache();
for (auto &argvals_map : origin_eval_cache) {
auto argvals = argvals_map.first;
args_vector.push_back(argvals);
@ -667,6 +701,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
MS_EXCEPTION_IF_NULL(func->func_graph());
if (status == kSpecializeFindUniqueArgvalPoly ||
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
@ -697,6 +732,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
namespace {
void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
int64_t i = 0;
const EvalResultCache &map = evaluator_cache_mgr->GetCache();
@ -706,6 +742,7 @@ void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const A
}
bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
MS_EXCEPTION_IF_NULL(func);
if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
MS_LOG(DEBUG) << "High order primitive return POLY.";
return true;
@ -733,6 +770,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
MS_EXCEPTION_IF_NULL(result);
EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
auto data = evaluator_cache_mgr->GetValue(argvals);
if (data != nullptr) {
*result = std::make_pair(argvals, data->abstract());
@ -740,13 +778,16 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
}
DumpEvaluatorCache(evaluator_cache_mgr, argvals);
MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
const EvalResultCache &choices = GetEvalCache(eval)->GetCache();
auto cache = GetEvalCache(eval);
MS_EXCEPTION_IF_NULL(cache);
const EvalResultCache &choices = cache->GetCache();
if (choices.get(argvals) != nullptr) {
*result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
MS_EXCEPTION_IF_NULL(cache->GetValue(argvals));
*result = std::make_pair(argvals, cache->GetValue(argvals)->abstract());
return kSpecializeSuccess;
} else if (choices.size() == 1) {
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
MS_EXCEPTION_IF_NULL(choices.begin()->second);
*result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
return kSpecializeSuccess;
} else if (choices.empty()) {
@ -768,11 +809,14 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
}
}
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
MS_EXCEPTION_IF_NULL(prim);
auto &prim_attrs = prim->attrs();
bool is_attr_same = true;
for (auto &item : *attrs) {
auto itr = prim_attrs.find(item.first);
if (itr != prim_attrs.end()) {
MS_EXCEPTION_IF_NULL(itr->second);
MS_EXCEPTION_IF_NULL(item.second);
if (!(*(itr->second) == *(item.second))) {
is_attr_same = false;
break;

View File

@ -64,6 +64,7 @@ class RemoveMonad {
}
void RemoveMonadFromRandomNodes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
@ -79,6 +80,7 @@ class RemoveMonad {
}
void RemoveRandomNodesFromMonadChain(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const size_t first_index = 1;

View File

@ -22,11 +22,15 @@ namespace mindspore {
namespace abstract {
AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
const CNodePtr current_cnode) {
MS_EXCEPTION_IF_NULL(current_cnode);
MS_EXCEPTION_IF_NULL(evaluator);
AbstractBasePtrList args_abs_list;
auto &inputs = current_cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
auto config = engine->MakeConfig(inputs[i], current_context_, current_context_->func_graph());
auto abs = config->ObtainEvalResult()->abstract();
auto result = config->ObtainEvalResult();
MS_EXCEPTION_IF_NULL(result);
auto abs = result->abstract();
args_abs_list.push_back(abs);
}
args_abs_list = evaluator->NormalizeArgs(args_abs_list);
@ -36,6 +40,8 @@ AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &eng
AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr &fg_evaluator,
const AbstractFunctionPtr &graph_func) {
MS_EXCEPTION_IF_NULL(graph_func);
MS_EXCEPTION_IF_NULL(fg_evaluator);
AnalysisContextPtr parent_context = nullptr;
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
@ -55,6 +61,8 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
// Inner jump implementation.
StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
const AbstractFunctionPtr &graph_func) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(current_cnode);
// Get the evaluator for func graph.
auto evaluator = engine->GetEvaluatorFor(graph_func);
auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
@ -102,6 +110,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
// Check if we need branch to another func graph.
StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
MS_EXCEPTION_IF_NULL(engine);
auto &current_node = CurrentNode();
if (!current_node->isa<CNode>()) {
return nullptr;
@ -126,19 +135,23 @@ StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
// Run one step in current func graph.
EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
MS_EXCEPTION_IF_NULL(engine);
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
<< ") = " << node_eval_result->abstract()->ToString();
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
<< (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");
return node_eval_result;
}
// Return back from child func graph.
void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame,
const EvalResultPtr &eval_result) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(last_stack_frame);
MS_EXCEPTION_IF_NULL(eval_result);
// Overwrite the result if func graph is stub.
EvalResultPtr result = eval_result;
if (last_stack_frame->func_graph()->stub()) {

View File

@ -80,6 +80,7 @@ size_t StackFrameDepth() { return stack_frame_depth; }
size_t StackFrameMaxDepth() { return stack_frame_max_depth; }
bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
MS_EXCEPTION_IF_NULL(arg_spec);
if (dyn_cast<AbstractScalar>(arg_spec)) {
auto v = arg_spec->GetValueTrack();
if (v->isa<SymbolicKeyInstance>()) {
@ -91,6 +92,7 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
MS_EXCEPTION_IF_NULL(arg1);
return arg1->Join(arg2);
}
return nullptr;
@ -121,6 +123,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
AnalysisSchedule::GetInstance().Reset();
AnalysisResult result;
try {
MS_EXCEPTION_IF_NULL(func_graph);
ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
@ -241,42 +244,6 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
return eval_result;
}
void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
auto &list = trace::GetCNodeDebugStack();
if (list.empty()) {
return;
}
auto &previous_stack = list.back();
MS_EXCEPTION_IF_NULL(previous_stack->node());
MS_EXCEPTION_IF_NULL(conf->node());
auto previous_cnode_fg = previous_stack->node()->func_graph();
auto current_cnode_fg = conf->node()->func_graph();
if (previous_cnode_fg != current_cnode_fg) { // Normal.
return;
}
if (forward_count_ != 0) { // Ignore Forward Config.
return;
}
auto &graph_stack = trace::GetCurrenGraphEvalStack();
if (graph_stack.empty()) {
return;
}
auto top_context = graph_stack.back().first;
auto top_context_fg = top_context->func_graph();
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
return;
}
MS_LOG(ERROR) << "Should not use call stack in the same function: " << top_context_fg->ToString() << ", for "
<< conf->node()->DebugString(2);
for (size_t i = 0; i < list.size(); ++i) {
auto old_conf = list[i];
MS_LOG(ERROR) << " #" << i << ": " << old_conf->node()->DebugString(2) << ", in "
<< old_conf->context()->func_graph()->ToString();
}
DumpIR("use_stack_error.ir", conf->node()->func_graph());
MS_LOG(EXCEPTION) << "To check above CNode stack and dumped use_stack_error.ir";
}
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(value_node);
@ -344,6 +311,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
}
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(func);
ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
@ -580,6 +548,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
}
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
MS_EXCEPTION_IF_NULL(orig_conf);
MS_EXCEPTION_IF_NULL(new_conf);
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
@ -590,6 +560,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
if (new_conf->node()->isa<CNode>()) {
auto new_cnode = new_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(new_cnode);
MS_EXCEPTION_IF_NULL(old_cnode->func_graph());
if (old_cnode->func_graph() == new_cnode->func_graph()) {
MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->ToString()
<< ", as origin node should be in order list, origin_node: " << old_cnode->ToString();
@ -622,6 +593,7 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
MS_EXCEPTION_IF_NULL(evaluator);
static std::mutex fg_lock;
std::lock_guard<std::mutex> infer_lock(fg_lock);
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
@ -650,6 +622,8 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
const EvalTraceRevIter &it, bool *continue_flag) {
MS_EXCEPTION_IF_NULL(continue_flag);
MS_EXCEPTION_IF_NULL(eval);
*continue_flag = false;
// Find latest entry function to handle nested recursion.
EvaluatorPtr latest_entry = eval;
@ -769,6 +743,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
}
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
return true;
}
@ -860,6 +835,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
std::vector<AsyncAbstractPtr> branchAsyncResults;
for (auto &evaluator : evaluators) {
MS_EXCEPTION_IF_NULL(evaluator);
SetUndeterminedFlag(evaluator, possible_parent_fg);
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
// Control the order to run.
@ -935,7 +911,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
for (auto eval : evaluators) {
MS_EXCEPTION_IF_NULL(eval);
(void)SetUndeterminedFlag(eval, possible_parent_fg);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
@ -1006,6 +981,7 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con
}
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(value);
AnfNodePtr anf_node = nullptr;
if (conf != nullptr) {
anf_node = conf->node();

View File

@ -282,7 +282,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
bool enable_recursive_eval() const { return enable_recursive_eval_; }
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);

View File

@ -48,6 +48,7 @@ void ValidateOperation(const AnfNodePtr &node) {
// Primitive must in whitelist
auto prim = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(prim);
if (abstract::IsInWhiteList(prim)) {
return;
}
@ -70,6 +71,7 @@ void ValidateOperation(const AnfNodePtr &node) {
}
bool CheckAbstractScalar(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
AbstractBasePtr ptrBase = node->abstract();
if (ptrBase->isa<AbstractScalar>()) {
TypePtr ptrType = ptrBase->GetTypeTrack();