forked from mindspore-Ecosystem/mindspore
code self check
This commit is contained in:
parent
e187cfc889
commit
8e5a250c21
|
@ -153,6 +153,7 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
|||
|
||||
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
||||
const abstract::AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_LOG(DEBUG) << "ProgramSpecialize start";
|
||||
abstract::ProgramSpecializer spc(res->engine());
|
||||
FuncGraphPtr result = spc.Run(func_graph, context);
|
||||
|
@ -165,6 +166,7 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_
|
|||
|
||||
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
||||
const abstract::AbstractBasePtrList &args_spec) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_LOG(DEBUG) << "Renormalize start";
|
||||
#ifdef ENABLE_PROFILE
|
||||
double t1 = GetTime();
|
||||
|
@ -250,6 +252,7 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
|
|||
}
|
||||
|
||||
bool ParseAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (!res->input()) {
|
||||
MS_LOG(EXCEPTION) << "Parse error";
|
||||
}
|
||||
|
@ -293,8 +296,8 @@ bool ParseAction(const ResourcePtr &res) {
|
|||
// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
|
||||
// all obj_map's graph shared base_graph
|
||||
bool CombineLikeGraphs(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto &obj_map = parse::data_converter::GetObjGraphs();
|
||||
|
||||
for (auto it : obj_map) {
|
||||
auto &graphs = it.second;
|
||||
MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
|
||||
|
@ -313,6 +316,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
|||
for (auto &fv : fg->paramter_obj_nodes()) {
|
||||
TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
|
||||
auto param = base_graph->add_parameter();
|
||||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
auto &node_users = res->manager()->node_users()[fv];
|
||||
for (auto &n : node_users) {
|
||||
// If the user is not in this graph, no need to change.
|
||||
|
@ -321,6 +325,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
|||
continue;
|
||||
}
|
||||
auto repl_n = cloned->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(repl_n);
|
||||
repl_n->set_input(IntToSize(n.second), param);
|
||||
}
|
||||
}
|
||||
|
@ -346,6 +351,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool SymbolResolveAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
|
||||
}
|
||||
|
@ -367,6 +373,7 @@ bool SymbolResolveAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool AutoMonadAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Auto-Monad failed, manager is null";
|
||||
}
|
||||
|
@ -379,6 +386,7 @@ bool AutoMonadAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool OrderEnforceAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
|
||||
}
|
||||
|
@ -391,6 +399,7 @@ bool OrderEnforceAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool RemoveRandomOpMonadAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove-Random-Op-Monad error, manager is null";
|
||||
}
|
||||
|
@ -403,6 +412,7 @@ bool RemoveRandomOpMonadAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
|
||||
}
|
||||
|
@ -413,6 +423,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool AbstractSpecializeAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
||||
}
|
||||
|
@ -428,8 +439,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
// get the hyper parameter
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
if (param_node->has_default()) {
|
||||
auto value = param_node->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
|
@ -466,6 +479,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
size_t counter = 0;
|
||||
for (auto &pass : passes) {
|
||||
WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
|
||||
|
@ -513,16 +527,8 @@ bool VmOptimizeAction(const ResourcePtr &res) {
|
|||
return OptimizeAction(res, kVmPasses);
|
||||
}
|
||||
|
||||
bool PynativeOptimizeAction(const ResourcePtr &resource) {
|
||||
WITH(MsProfile::GetProfile())[&resource]() { (void)OptimizeAction(resource, kPynativePasses); };
|
||||
#ifdef ENABLE_PROFILE
|
||||
MsProfile::Print();
|
||||
MsProfile::Reset();
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PynativeElimOpt(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
|
||||
}
|
||||
|
@ -564,6 +570,7 @@ bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
bool TaskEmitAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
|
||||
CheckGraphOutputConstOrParameter(res->func_graph())) {
|
||||
return true;
|
||||
|
@ -613,6 +620,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool ExecuteAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
|
||||
CheckGraphOutputConstOrParameter(res->func_graph())) {
|
||||
return true;
|
||||
|
@ -673,6 +681,7 @@ bool StartFLWorkerAction(const ResourcePtr &) {
|
|||
}
|
||||
|
||||
bool StartPSServerAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto &ps = ps::ParameterServer::GetInstance();
|
||||
ps.Run(func_graph);
|
||||
|
@ -680,6 +689,7 @@ bool StartPSServerAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool StartServerAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
|
||||
uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
|
||||
|
@ -735,6 +745,8 @@ bool StartPSSchedulerAction(const ResourcePtr &) {
|
|||
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
|
||||
// the final solution will be proposed later as a parallel feature.
|
||||
bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
auto &node_users = res->manager()->node_users();
|
||||
auto &users = node_users[value_node];
|
||||
auto used_by_keep_value_prim =
|
||||
|
@ -747,6 +759,7 @@ bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &r
|
|||
auto prim_node = cnode->input(0);
|
||||
if (IsValueNode<Primitive>(prim_node)) {
|
||||
auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// value_node is referenced by some parallel primitive
|
||||
return prim->HasAttr("keep_value_node_input");
|
||||
}
|
||||
|
@ -756,10 +769,11 @@ bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &r
|
|||
}
|
||||
|
||||
bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
||||
}
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto manager = res->manager();
|
||||
// Remove duplicated value nodes, due to replace operation, can't use reference.
|
||||
auto value_nodes = func_graph->value_nodes();
|
||||
|
@ -796,8 +810,8 @@ bool OptActionVmPyStub(const ResourcePtr &res) {
|
|||
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
|
@ -817,8 +831,8 @@ bool OptActionGePyStub(const ResourcePtr &res) {
|
|||
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
|
||||
if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
|
||||
// Renomalize
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
|
|
|
@ -36,7 +36,6 @@ bool AutoMonadAction(const ResourcePtr &res);
|
|||
bool AbstractSpecializeAction(const ResourcePtr &res);
|
||||
bool GeOptimizeAction(const ResourcePtr &res);
|
||||
bool VmOptimizeAction(const ResourcePtr &res);
|
||||
bool PynativeOptimizeAction(const ResourcePtr &res);
|
||||
bool PynativeElimOpt(const ResourcePtr &res);
|
||||
bool TaskEmitAction(const ResourcePtr &res);
|
||||
bool ExecuteAction(const ResourcePtr &res);
|
||||
|
|
|
@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
}
|
||||
ValuePtr converted = nullptr;
|
||||
bool matched = false;
|
||||
auto &&converters = GetDataConverters();
|
||||
auto converters = GetDataConverters();
|
||||
for (auto &converter : converters) {
|
||||
if (converter->Matched(obj)) {
|
||||
converted = converter->ConvertPyObject(obj, use_signature, dtype);
|
||||
|
|
|
@ -63,7 +63,9 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
|
|||
|
||||
// Write variable records the variable name to corresponding node
|
||||
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
|
||||
MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var " << var_name << " with node "
|
||||
<< node->DebugString();
|
||||
auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false));
|
||||
if (!is_new_name) {
|
||||
// If a cnode variable with same name already existed but not used,
|
||||
|
@ -76,9 +78,10 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
auto hidden_node = iter->second.first;
|
||||
auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
|
||||
if (!is_used && is_isolated) {
|
||||
MS_EXCEPTION_IF_NULL(hidden_node);
|
||||
MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
|
||||
<< node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
|
||||
<< "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
|
||||
<< "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||
<< ", Line: " << trace::GetDebugInfo(hidden_node->debug_info(), "", kSourceLineTipDiscard);
|
||||
AddIsolatedNode(hidden_node);
|
||||
}
|
||||
|
@ -124,7 +127,8 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
|||
debug_info->set_name(var);
|
||||
TraceGuard guard(std::make_shared<TracePhi>(debug_info));
|
||||
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
|
||||
MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
|
||||
<< phi_param->ToString() << " for " << var;
|
||||
func_graph()->add_parameter(phi_param);
|
||||
phi_nodes_[phi_param] = var;
|
||||
WriteVariable(var, phi_param);
|
||||
|
@ -150,8 +154,9 @@ AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
|
|||
|
||||
// Resolve class member, two possible: method, member variable
|
||||
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
|
||||
py::object namespace_var =
|
||||
parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
|
||||
auto ast = parser_.ast();
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(attr);
|
||||
return MakeResolve(name_space, symbol);
|
||||
|
@ -168,8 +173,13 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
|||
auto bits_str = value.substr(start);
|
||||
return MakeResolveClassMember(bits_str);
|
||||
}
|
||||
py::tuple namespace_info = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||
auto ast = parser_.ast();
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
|
||||
const size_t namespace_info_size = 2;
|
||||
if (namespace_info.size() < namespace_info_size) {
|
||||
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
|
||||
}
|
||||
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
|
||||
if (namespace_info[0].is_none()) {
|
||||
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
|
||||
|
@ -179,17 +189,15 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
|
|||
// If the size of namespace_var is less than 3, the default error information is used.
|
||||
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
|
||||
}
|
||||
if (namespace_info.size() < namespace_info_size) {
|
||||
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
|
||||
}
|
||||
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
|
||||
return MakeResolve(name_space, symbol);
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
||||
py::tuple namespace_var = parser_.ast()->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
|
||||
auto ast = parser_.ast();
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
|
||||
const size_t namespace_var_size = 2;
|
||||
if (namespace_var.size() < namespace_var_size) {
|
||||
MS_EXCEPTION(NameError) << "namespace_var is less than 2";
|
||||
|
@ -200,28 +208,32 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
|||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
|
||||
MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , "
|
||||
<< ((std::string)resolve_symbol->symbol());
|
||||
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
|
||||
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
|
||||
ValueNodePtr module_node = NewValueNode(name_space);
|
||||
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
|
||||
auto node = func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
|
||||
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
|
||||
return node;
|
||||
}
|
||||
|
||||
// Add input for the block's phi parameter
|
||||
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||
MS_EXCEPTION_IF_NULL(phi);
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(phi->debug_info()));
|
||||
std::string var = phi_nodes_[phi];
|
||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
|
||||
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " set phi " << phi->ToString()
|
||||
<< " for var " << var;
|
||||
auto removable = CollectRemovablePhi(phi);
|
||||
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
|
||||
if (removable) {
|
||||
MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
|
||||
MS_LOG(DEBUG) << "remove the phi when call graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||
<< " var " << var;
|
||||
return;
|
||||
}
|
||||
for (auto &pred : prev_blocks_) {
|
||||
MS_EXCEPTION_IF_NULL(pred);
|
||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
|
||||
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " pred_blocks_ "
|
||||
<< (pred->func_graph_ ? pred->func_graph_->ToString() : "FG(Null)");
|
||||
AnfNodePtr arg_node = pred->ReadVariable(var);
|
||||
CNodePtr jump = pred->jumps_[this];
|
||||
MS_EXCEPTION_IF_NULL(jump);
|
||||
|
@ -235,18 +247,18 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
|
|||
MS_EXCEPTION_IF_NULL(prev);
|
||||
AnfNodePtr temp_node = prev->ReadVariable(var);
|
||||
MS_EXCEPTION_IF_NULL(temp_node);
|
||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var
|
||||
<< " is " << temp_node->DebugString();
|
||||
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
|
||||
<< (phi ? phi->ToString() : "null") << " for var " << var << " is " << temp_node->DebugString();
|
||||
if (temp_node != phi) {
|
||||
if (arg_node == nullptr) {
|
||||
arg_node = temp_node;
|
||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString()
|
||||
<< " may be replaced by node " << arg_node->DebugString();
|
||||
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
|
||||
<< (phi ? phi->ToString() : "null") << " may be replaced by node " << arg_node->DebugString();
|
||||
} else if (temp_node == arg_node) {
|
||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node "
|
||||
<< arg_node->DebugString();
|
||||
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
|
||||
<< (phi ? phi->ToString() : "null") << " is same as node " << arg_node->DebugString();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "phi " << phi->ToString()
|
||||
MS_LOG(DEBUG) << "phi " << (phi ? phi->ToString() : "null")
|
||||
<< " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString()
|
||||
<< ", node2: " << temp_node->DebugString();
|
||||
return nullptr;
|
||||
|
@ -283,8 +295,8 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
|||
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
|
||||
if (arg_node != nullptr) {
|
||||
arg_node->set_debug_info(phi->debug_info());
|
||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
|
||||
<< arg_node->DebugString();
|
||||
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
|
||||
<< " can be replaced with " << arg_node->DebugString();
|
||||
// Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
|
||||
WriteVariable(var, arg_node);
|
||||
removable_phis_[phi] = arg_node;
|
||||
|
@ -301,9 +313,10 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
|||
if (phi_iter.second->isa<Parameter>()) {
|
||||
const auto ¶m = phi_iter.second->cast<ParameterPtr>();
|
||||
if (param == phi) {
|
||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
|
||||
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
|
||||
<< " in graph " << arg_node->func_graph()->ToString();
|
||||
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " var "
|
||||
<< phi_iter.first->DebugString() << " can be replaced from " << param->DebugString()
|
||||
<< " with " << arg_node->DebugString() << " in graph "
|
||||
<< (arg_node->func_graph() ? arg_node->func_graph()->ToString() : "FG(Null)");
|
||||
prev->removable_phis_[phi_iter.first] = arg_node;
|
||||
}
|
||||
}
|
||||
|
@ -329,22 +342,25 @@ void FunctionBlock::Mature() {
|
|||
|
||||
// Force the condition node to bool using bool operation
|
||||
CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
|
||||
MS_EXCEPTION_IF_NULL(cond);
|
||||
TraceGuard trace_guard(std::make_shared<TraceForceBool>(cond->debug_info()));
|
||||
CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
|
||||
CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
|
||||
return op_apply_node;
|
||||
}
|
||||
|
||||
CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
|
||||
MS_EXCEPTION_IF_NULL(cond);
|
||||
TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
|
||||
CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
|
||||
CNodePtr op_apply_node = func_graph_->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond});
|
||||
return op_apply_node;
|
||||
}
|
||||
|
||||
// Perform a jump from this block to target block
|
||||
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr &node) {
|
||||
if (func_graph()->get_return() != nullptr) {
|
||||
MS_EXCEPTION_IF_NULL(target_block);
|
||||
if (func_graph_->get_return() != nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
|
||||
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
|
||||
<< trace::GetDebugInfo(func_graph_->get_return()->debug_info());
|
||||
}
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
|
||||
|
@ -352,25 +368,27 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const AnfNodePtr
|
|||
input_nodes.emplace_back(node);
|
||||
}
|
||||
|
||||
CNodePtr jump = func_graph()->NewCNodeInOrder(input_nodes);
|
||||
CNodePtr jump = func_graph_->NewCNodeInOrder(input_nodes);
|
||||
jumps_[target_block.get()] = jump;
|
||||
target_block->AddPrevBlock(shared_from_this());
|
||||
func_graph()->set_output(jump);
|
||||
func_graph_->set_output(jump);
|
||||
}
|
||||
|
||||
// Perform a conditional jump using switch operation.
|
||||
// The first CNode select graph with condition, and than execute this graph
|
||||
void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block, bool unroll_loop) {
|
||||
if (func_graph()->get_return() != nullptr) {
|
||||
MS_EXCEPTION_IF_NULL(true_block);
|
||||
MS_EXCEPTION_IF_NULL(false_block);
|
||||
if (func_graph_->get_return() != nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
|
||||
<< trace::GetDebugInfo(func_graph()->get_return()->debug_info());
|
||||
<< trace::GetDebugInfo(func_graph_->get_return()->debug_info());
|
||||
}
|
||||
CNodePtr switch_app =
|
||||
func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
CNodePtr switch_app_new = func_graph()->NewCNodeInOrder({switch_app});
|
||||
func_graph()->set_output(switch_app_new);
|
||||
func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
CNodePtr switch_app_new = func_graph_->NewCNodeInOrder({switch_app});
|
||||
func_graph_->set_output(switch_app_new);
|
||||
}
|
||||
|
||||
// Create cnode for the assign statement like 'self.target = source'.
|
||||
|
@ -381,7 +399,7 @@ void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &s
|
|||
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
|
||||
auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
|
||||
MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2) << ", block: " << this
|
||||
<< "/" << (func_graph() ? func_graph()->ToString() : "FG(Null)")
|
||||
<< "/" << func_graph_->ToString()
|
||||
<< ", Line: " << trace::GetDebugInfo(assign_node->debug_info(), "", kSourceLineTipDiscard);
|
||||
AddIsolatedNode(assign_node);
|
||||
}
|
||||
|
@ -430,15 +448,15 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
|
|||
if (isolated_nodes_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> states;
|
||||
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (auto &node : isolated_nodes_) {
|
||||
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
|
||||
if (node->func_graph() == func_graph()) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph_->ToString();
|
||||
if (node->func_graph() == func_graph_) {
|
||||
states.emplace_back(node);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
|
||||
MS_LOG(INFO) << "Ignored FV dependency, node: " << node->DebugString(2) << " in " << func_graph_->ToString();
|
||||
}
|
||||
}
|
||||
isolated_nodes_.clear();
|
||||
|
@ -452,11 +470,11 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
|
|||
// do not need to MakeTuple, just use the node.
|
||||
state = states[1];
|
||||
} else {
|
||||
state = func_graph()->NewCNode(states);
|
||||
state = func_graph_->NewCNode(states);
|
||||
}
|
||||
|
||||
AnfNodePtr old_output = nullptr;
|
||||
auto return_node = func_graph()->get_return();
|
||||
auto return_node = func_graph_->get_return();
|
||||
if (return_node) {
|
||||
const size_t return_input_size = 2;
|
||||
if (return_node->inputs().size() < return_input_size) {
|
||||
|
@ -466,15 +484,15 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
|
|||
} else {
|
||||
old_output = NewValueNode(kNone);
|
||||
}
|
||||
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
|
||||
CNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
|
||||
AnfNodePtr stop_grad_node = func_graph_->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
|
||||
CNodePtr depend_node = func_graph_->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
|
||||
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
|
||||
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
|
||||
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
|
||||
MS_EXCEPTION_IF_NULL(state);
|
||||
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
|
||||
<< ", state: " << state->DebugString(2);
|
||||
func_graph()->set_output(depend_node, true);
|
||||
func_graph_->set_output(depend_node, true);
|
||||
}
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -61,6 +61,7 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
|
|||
}
|
||||
|
||||
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
|
||||
return kFloat32;
|
||||
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
|
||||
|
@ -72,6 +73,7 @@ TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
|
|||
|
||||
// If any mixed precision flag add a cast node after the parameter node.
|
||||
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
TypePtr dst_type;
|
||||
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
|
||||
dst_type = kFloat32;
|
||||
|
@ -146,12 +148,18 @@ void Parser::CleanParserResource() {
|
|||
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
|
||||
// Check whether the functions referred by this function and itself are missing 'return' statement
|
||||
auto mng = Manage(fn, false);
|
||||
MS_EXCEPTION_IF_NULL(ast);
|
||||
for (const auto &func_graph : mng->func_graphs()) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->get_return() != nullptr) {
|
||||
continue;
|
||||
}
|
||||
py::object node = ast->GetAstNode();
|
||||
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
constexpr auto kMinListSize = 2;
|
||||
if (ret.size() < kMinListSize) {
|
||||
MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
|
||||
}
|
||||
py::str desc =
|
||||
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
|
||||
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
|
||||
|
@ -178,49 +186,55 @@ FuncGraphPtr Parser::ParseFuncGraph() {
|
|||
void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
|
||||
py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
|
||||
py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
|
||||
block->func_graph()->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto block_fg = block->func_graph();
|
||||
block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
|
||||
|
||||
py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
|
||||
block->func_graph()->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
|
||||
block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
|
||||
|
||||
py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
|
||||
block->func_graph()->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
|
||||
block_fg->set_kwonlyargs_count(SizeToLong(kwonly_args.size()));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(ast_);
|
||||
py::list args = ast_->GetArgs(fn_node);
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
||||
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (arg_name == "self") {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
TraceGuard guard(GetLocation(args[i]));
|
||||
auto para_node = std::make_shared<Parameter>(block->func_graph());
|
||||
auto para_node = std::make_shared<Parameter>(block_fg);
|
||||
MS_EXCEPTION_IF_NULL(para_node);
|
||||
para_node->set_name(arg_name);
|
||||
para_node->debug_info()->set_name(arg_name);
|
||||
block->func_graph()->add_parameter(para_node);
|
||||
AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node);
|
||||
block_fg->add_parameter(para_node);
|
||||
AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
|
||||
block->WriteVariable(arg_name, para_after_cast);
|
||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
|
||||
}
|
||||
}
|
||||
|
||||
void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::list defaults = ast_->GetArgsDefaultValues(fn_node);
|
||||
py::list args = ast_->GetArgs(fn_node);
|
||||
std::vector<std::string> namelist_for_default_value;
|
||||
std::vector<AnfNodePtr> default_values;
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
||||
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (arg_name == "self") {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
namelist_for_default_value.push_back(arg_name);
|
||||
if (i >= defaults.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index:" << i << " out of range:" << defaults.size();
|
||||
}
|
||||
if (py::isinstance<py::none>(defaults[i])) {
|
||||
default_values.push_back(NewValueNode(kNull));
|
||||
} else {
|
||||
|
@ -232,7 +246,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
|
|||
|
||||
ScopePtr Parser::GetScopeForParseFunction() {
|
||||
ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
|
||||
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
|
||||
if (!py::isinstance<py::none>(scope_str)) {
|
||||
auto scope_name = py::cast<std::string>(scope_str);
|
||||
|
@ -246,7 +260,7 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
ScopePtr scope = GetScopeForParseFunction();
|
||||
// The node created in the parsefunction context, will inherit the scope created using scope_guard
|
||||
ScopeGuard scope_guard(scope);
|
||||
TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node));
|
||||
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
|
||||
FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this);
|
||||
if (block != nullptr) {
|
||||
pFunBlock->AddPrevBlock(block);
|
||||
|
@ -257,14 +271,12 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
auto current_fg = pFunBlock->func_graph();
|
||||
auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
|
||||
MS_LOG(DEBUG) << "The function name is " << function_name;
|
||||
|
||||
current_fg->debug_info()->set_name(function_name);
|
||||
MS_EXCEPTION_IF_NULL(ast_);
|
||||
py::list deco_list = node.attr("decorator_list");
|
||||
if (!deco_list.empty()) {
|
||||
current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
|
||||
}
|
||||
|
||||
bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
|
||||
if (!ast_->obj().is(ast_->function())) {
|
||||
set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
|
||||
|
@ -289,6 +301,7 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
|
|||
|
||||
// Add unused variables as isolate nodes.
|
||||
for (auto &func_block : func_block_list_) {
|
||||
MS_EXCEPTION_IF_NULL(func_block);
|
||||
if (func_block->func_graph()->get_return() != nullptr) {
|
||||
// Find unused variables.
|
||||
func_block->FindIsolatedNodes();
|
||||
|
@ -313,6 +326,7 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::objec
|
|||
for (size_t i = 0; i < count; ++i) {
|
||||
auto node = node_list[i];
|
||||
block = ParseStatement(block, node);
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// Insert appropriate depended items for the function block if it has a return node
|
||||
if (block->func_graph()->get_return() != nullptr) {
|
||||
// Skip statements after 'return' (or 'break', 'continue').
|
||||
|
@ -352,9 +366,8 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object
|
|||
// Check the node type
|
||||
AstMainType node_main_type = node_type->main_type();
|
||||
if (node_main_type != AST_MAIN_TYPE_EXPR) {
|
||||
MS_LOG(ERROR) << "Node type is error : " << node_main_type;
|
||||
errcode_ = PARSE_NODE_TYPE_NO_MATCH;
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "Node type is error : " << node_main_type;
|
||||
}
|
||||
// Call the process function
|
||||
std::string node_name = node_type->node_name();
|
||||
|
@ -381,9 +394,15 @@ FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::obje
|
|||
// False, None, None
|
||||
//
|
||||
// Check the expand info result
|
||||
if (expand_info.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Empty expand_info.";
|
||||
}
|
||||
auto is_expand = py::cast<bool>(expand_info[0]);
|
||||
if (is_expand) {
|
||||
// Process the expr statement
|
||||
if (expand_info.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "expand_info size:" << expand_info.size() << " less than 2.";
|
||||
}
|
||||
py::object value_object = expand_info[1];
|
||||
// Make a Expr CNode.
|
||||
AnfNodePtr call_node = ParseExprNode(block, value_object);
|
||||
|
@ -429,6 +448,8 @@ LocationPtr Parser::GetLocation(const py::object &node) const {
|
|||
|
||||
void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
|
||||
const FunctionBlockPtr &false_block) {
|
||||
MS_EXCEPTION_IF_NULL(true_block);
|
||||
MS_EXCEPTION_IF_NULL(false_block);
|
||||
true_block->AddPrevBlock(pre_block);
|
||||
true_block->Mature();
|
||||
|
||||
|
@ -445,10 +466,9 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
|
|||
py::object value = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
|
||||
// Create the cnode
|
||||
CNodePtr pReturnCNode = block->func_graph()->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
|
||||
|
||||
block->func_graph()->set_return(pReturnCNode);
|
||||
|
||||
auto block_fg = block->func_graph();
|
||||
CNodePtr pReturnCNode = block_fg->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode});
|
||||
block_fg->set_return(pReturnCNode);
|
||||
return block;
|
||||
}
|
||||
|
||||
|
@ -462,17 +482,17 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
|
|||
// Create left and right ANF node
|
||||
AnfNodePtr left_node = ParseExprNode(block, left);
|
||||
if (left_node == nullptr) {
|
||||
MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "DoBinOp process left node failed: " << errcode();
|
||||
}
|
||||
AnfNodePtr right_node = ParseExprNode(block, right);
|
||||
if (right_node == nullptr) {
|
||||
MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "DoBinOp process right node failed:" << errcode();
|
||||
}
|
||||
// Resolve the op
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_node = block->MakeResolveAstOp(op);
|
||||
// Create apply node
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
|
||||
}
|
||||
|
||||
|
@ -480,6 +500,7 @@ AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &no
|
|||
MS_LOG(DEBUG) << "Process ast Name";
|
||||
auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
|
||||
MS_LOG(DEBUG) << "The Name id is " << name_id;
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (block->IsGlobalVar(name_id)) {
|
||||
return block->MakeResolveSymbol(name_id);
|
||||
}
|
||||
|
@ -509,9 +530,8 @@ AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
|
|||
return NewValueNode(data);
|
||||
} else {
|
||||
// no else actually
|
||||
MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj);
|
||||
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "Unsupported Num type : " << (std::string)py::str(obj);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -558,14 +578,13 @@ AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object
|
|||
} else if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
|
||||
return NewValueNode(kNone);
|
||||
} else {
|
||||
// no else actually
|
||||
MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
|
||||
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
|
||||
return nullptr;
|
||||
}
|
||||
// no else actually
|
||||
errcode_ = PARSE_NODE_TYPE_UNKNOWN;
|
||||
MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
|
||||
}
|
||||
AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
|
||||
std::vector<AnfNodePtr> make_tuple_nodes;
|
||||
make_tuple_nodes.push_back(make_tuple_op);
|
||||
|
@ -575,6 +594,7 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v
|
|||
}
|
||||
|
||||
AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::object father_class;
|
||||
const size_t expect_args_size = 2;
|
||||
if (args.empty()) {
|
||||
|
@ -588,7 +608,7 @@ AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &arg
|
|||
} else {
|
||||
MS_EXCEPTION(ArgumentError) << "When call 'super', the args number should be 0 or 2, but got" << args.size() << ".";
|
||||
}
|
||||
py::object target_class_instance = ast()->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast()->obj());
|
||||
py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
|
||||
py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
|
||||
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
SymbolPtr symbol = std::make_shared<Symbol>("namespace");
|
||||
|
@ -626,6 +646,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
|
|||
|
||||
CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
|
||||
const std::vector<AnfNodePtr> &packed_arguments) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> unpack_call_nodes;
|
||||
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
|
||||
unpack_call_nodes.push_back(unpack_call_op);
|
||||
|
@ -640,6 +661,7 @@ AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const A
|
|||
const std::vector<AnfNodePtr> &packed_arguments,
|
||||
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
|
||||
// If there is keyword arguments or starred, using an unpack_call op to unpack the argument
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (need_unpack) {
|
||||
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
|
||||
}
|
||||
|
@ -654,6 +676,8 @@ AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const A
|
|||
|
||||
bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
|
||||
std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
|
||||
MS_EXCEPTION_IF_NULL(packed_arguments);
|
||||
MS_EXCEPTION_IF_NULL(group_arguments);
|
||||
bool need_unpack = false;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
|
||||
|
@ -679,6 +703,8 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
|||
bool need_unpack = false;
|
||||
py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
|
||||
if (!keywords.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
MS_EXCEPTION_IF_NULL(packed_arguments);
|
||||
need_unpack = true;
|
||||
std::vector<AnfNodePtr> keys;
|
||||
std::vector<AnfNodePtr> values;
|
||||
|
@ -708,14 +734,14 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
|||
// Process call attributes of class type define, eg: x.y()
|
||||
AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Attribute";
|
||||
|
||||
// Process class value,eg: self.xx
|
||||
if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
|
||||
if (ast_->IsClassMember(node)) {
|
||||
std::string var_name = "self.";
|
||||
std::string attr_name = node.attr("attr").cast<std::string>();
|
||||
(void)var_name.append(attr_name);
|
||||
auto attr_obj = ast()->obj().attr(attr_name.c_str());
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
|
||||
(py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
|
||||
py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
|
||||
|
@ -736,8 +762,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
py::object value_body = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_body);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Parse attribute failed";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "Parse attribute failed";
|
||||
}
|
||||
|
||||
// Process the node attr
|
||||
|
@ -761,24 +786,30 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
|
|||
// For python comparison ,there may be if x>y>5 ,
|
||||
// Which there is two ops , but we only support one now
|
||||
py::list ops = python_adapter::GetPyObjAttr(node, "ops");
|
||||
if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
|
||||
MS_EXCEPTION(NotSupportError)
|
||||
<< "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size();
|
||||
if (ops.size() != MAX_COMPARISON_OPS_SUPPORTED) {
|
||||
MS_EXCEPTION(NotSupportError) << "MindSpore only support comparison with operators with one now, ops size ="
|
||||
<< ops.size();
|
||||
}
|
||||
|
||||
py::object left = python_adapter::GetPyObjAttr(node, "left");
|
||||
py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
|
||||
if (comparators.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Comparators can't be empty.";
|
||||
}
|
||||
AnfNodePtr left_node = ParseExprNode(block, left);
|
||||
AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
|
||||
|
||||
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
|
||||
// If there is only one bool op now
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
if (value_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "value list is empty.";
|
||||
}
|
||||
if (value_list.size() == 1) {
|
||||
AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
|
||||
return first_node;
|
||||
|
@ -788,15 +819,15 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
|
|||
for (size_t i = 1; i < value_list.size(); i++) {
|
||||
rest.append(value_list[i]);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
FunctionBlockPtr true_block = nullptr;
|
||||
FunctionBlockPtr false_block = nullptr;
|
||||
auto block_fg = block->func_graph();
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
|
||||
true_block = MakeFunctionBlock(*this);
|
||||
}
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
|
||||
false_block = MakeFunctionBlock(*this);
|
||||
}
|
||||
MakeConditionBlocks(block, true_block, false_block);
|
||||
|
@ -820,12 +851,12 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
|
|||
b2->func_graph()->set_output(test_node);
|
||||
|
||||
auto cond_node = block->ForceToBoolNode(test_node);
|
||||
auto switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node,
|
||||
NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
auto switch_app =
|
||||
block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
|
||||
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
|
||||
auto switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
|
||||
auto switch_app_call = block_fg->NewCNodeInOrder(call_graph_nodes);
|
||||
return switch_app_call;
|
||||
}
|
||||
}
|
||||
|
@ -836,8 +867,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
|
|||
py::object op_node = python_adapter::GetPyObjAttr(node, "op");
|
||||
AstSubType op_type = ast_->GetOpType(op_node);
|
||||
if (op_type == AST_SUB_TYPE_UNKNOWN) {
|
||||
MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "ProcessBoolOp, got unknown op type";
|
||||
}
|
||||
py::list op_values = python_adapter::GetPyObjAttr(node, "values");
|
||||
return ProcessBoolOpValueList(block, op_values, op_type);
|
||||
|
@ -866,20 +896,21 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
|
|||
|
||||
// Get lambda args
|
||||
py::list args = ast_->GetArgs(node);
|
||||
auto block_fg = func_block->func_graph();
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
std::string arg = py::cast<std::string>(args[i].attr("arg"));
|
||||
TraceGuard guard(GetLocation(args[i]));
|
||||
auto para_node = std::make_shared<Parameter>(func_block->func_graph());
|
||||
auto para_node = std::make_shared<Parameter>(block_fg);
|
||||
para_node->debug_info()->set_name(arg);
|
||||
func_block->func_graph()->add_parameter(para_node);
|
||||
block_fg->add_parameter(para_node);
|
||||
func_block->WriteVariable(arg, para_node);
|
||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
|
||||
}
|
||||
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
|
||||
func_block->func_graph()->set_output(lambda_body_node);
|
||||
ValueNodePtr const_graph = NewValueNode(func_block->func_graph());
|
||||
block_fg->set_output(lambda_body_node);
|
||||
ValueNodePtr const_graph = NewValueNode(block_fg);
|
||||
return const_graph;
|
||||
}
|
||||
|
||||
|
@ -888,7 +919,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n
|
|||
MS_LOG(DEBUG) << "Process ast Tuple";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.size() == 0) {
|
||||
if (elts.empty()) {
|
||||
auto empty_tuple = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
|
||||
}
|
||||
|
@ -909,7 +940,7 @@ AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &no
|
|||
MS_LOG(DEBUG) << "Process ast List";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.size() == 0) {
|
||||
if (elts.empty()) {
|
||||
auto empty_list = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueList>(empty_list));
|
||||
}
|
||||
|
@ -934,7 +965,6 @@ AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::objec
|
|||
py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
|
||||
AnfNodePtr value = ParseExprNode(block, value_node);
|
||||
AnfNodePtr slice = ParseExprNode(block, slice_node);
|
||||
|
||||
return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
|
||||
}
|
||||
|
||||
|
@ -949,7 +979,6 @@ AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &n
|
|||
AnfNodePtr start_node = ParseExprNode(block, start);
|
||||
AnfNodePtr stop_node = ParseExprNode(block, stop);
|
||||
AnfNodePtr step_node = ParseExprNode(block, step);
|
||||
|
||||
return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
|
||||
}
|
||||
|
||||
|
@ -1060,12 +1089,13 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
|
||||
FunctionBlockPtr true_block = nullptr;
|
||||
FunctionBlockPtr false_block = nullptr;
|
||||
auto block_fg = block->func_graph();
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
|
||||
true_block = MakeFunctionBlock(*this);
|
||||
}
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
|
||||
false_block = MakeFunctionBlock(*this);
|
||||
}
|
||||
|
||||
|
@ -1073,7 +1103,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
|
||||
FunctionBlockPtr after_block = nullptr;
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
|
||||
after_block = MakeFunctionBlock(*this);
|
||||
}
|
||||
|
||||
|
@ -1178,7 +1208,8 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
|
|||
|
||||
int64_t Parser::GetForTransToWhileLoop() {
|
||||
// int64 support 63bits positive num mostly.
|
||||
if (max_for_loop_count_str_.size() > 63 || max_for_loop_count_str_.empty()) {
|
||||
constexpr auto kMaxNumLength = 10;
|
||||
if (max_for_loop_count_str_.size() > kMaxNumLength || max_for_loop_count_str_.empty()) {
|
||||
return MAX_FOR_LOOP_COUNT;
|
||||
}
|
||||
if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
|
||||
|
@ -1634,6 +1665,7 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje
|
|||
}
|
||||
|
||||
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &inp = node->cast<ParameterPtr>();
|
||||
const auto &iter = removable_phis.find(inp);
|
||||
if (iter == removable_phis.end()) {
|
||||
|
@ -1675,6 +1707,7 @@ void Parser::RemoveUnnecessaryPhis() {
|
|||
std::vector<AnfNodePtr> new_parameters(parameters.size());
|
||||
auto it = std::copy_if(
|
||||
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr ¶m) {
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
|
||||
});
|
||||
|
||||
|
@ -1762,8 +1795,7 @@ AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) {
|
|||
py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
|
||||
const size_t list_value_size = 2;
|
||||
if (list_value.size() < list_value_size) {
|
||||
MS_LOG(ERROR) << "The node of python method must has 2 values.";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << "The node of python method must has 2 values.";
|
||||
}
|
||||
auto node_name = py::cast<std::string>(list_value[0]);
|
||||
auto type = AstMainType(py::cast<int32_t>(list_value[1]));
|
||||
|
@ -1828,6 +1860,7 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph,
|
|||
std::size_t index = 0;
|
||||
std::vector<AnfNodePtr> old_cnodes;
|
||||
old_cnodes.emplace_back(param_node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto res = func_graph->NewCNodeInOrder({});
|
||||
std::vector<CNodePtr> new_cnodes;
|
||||
new_cnodes.emplace_back(res);
|
||||
|
@ -1835,7 +1868,6 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph,
|
|||
auto current = old_cnodes[index];
|
||||
auto current_new_cnode = new_cnodes[index];
|
||||
index++;
|
||||
MS_EXCEPTION_IF_NULL(current);
|
||||
if (current->isa<CNode>()) {
|
||||
auto &inputs = current->cast<CNodePtr>()->inputs();
|
||||
for (auto it = inputs.begin(); it != inputs.end(); it++) {
|
||||
|
|
|
@ -63,6 +63,8 @@ using mindspore::abstract::AnalysisContextPtr;
|
|||
using mindspore::validator::Validate;
|
||||
namespace {
|
||||
void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
|
@ -76,15 +78,16 @@ void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const Re
|
|||
} // namespace
|
||||
|
||||
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
|
||||
DoRenormalize(changed, func_graph, res);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TransformTopGraphPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Transform top graph error.";
|
||||
}
|
||||
|
@ -103,15 +106,16 @@ bool TransformTopGraphPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool CleanAfterOptAPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
|
||||
DoRenormalize(changed, func_graph, res);
|
||||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
|
||||
irpass.pynative_eliminate_,
|
||||
|
@ -137,6 +141,7 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
|
|||
}
|
||||
|
||||
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
|
||||
irpass.switch_simplify_,
|
||||
|
@ -513,6 +518,7 @@ void ReclaimOptimizer() {
|
|||
}
|
||||
|
||||
bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(ERROR) << "Opt passes int64_t error";
|
||||
return false;
|
||||
|
@ -551,6 +557,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
if (ps::PSContext::instance()->is_ps_mode()) {
|
||||
return true;
|
||||
|
@ -571,6 +578,7 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
||||
}
|
||||
|
@ -588,6 +596,7 @@ bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool CconvPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
FuncGraphPtr new_fg = LiftingClone(func_graph);
|
||||
|
@ -602,6 +611,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
|
|||
std::vector<AnfNodePtr> new_paras;
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
auto param_node = param->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
if (param_node->has_default()) {
|
||||
new_paras.push_back(param_node);
|
||||
continue;
|
||||
|
@ -618,6 +628,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
bool ValidatePass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
Validate(func_graph);
|
||||
|
|
|
@ -143,6 +143,7 @@ std::string GetCompileExceptionInfo() {
|
|||
}
|
||||
|
||||
void SetGpuLoopSink(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto func_graph = resource->func_graph();
|
||||
if (func_graph != nullptr && func_graph->manager() != nullptr) {
|
||||
auto manager = func_graph->manager();
|
||||
|
@ -452,13 +453,17 @@ ExecutorPy::~ExecutorPy() {
|
|||
|
||||
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
|
||||
MS_EXCEPTION_IF_NULL(root_node);
|
||||
MS_EXCEPTION_IF_NULL(fake_quant_table);
|
||||
std::string weight_name;
|
||||
auto x = root_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
||||
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
||||
} else {
|
||||
weight_name = weight_node->cast<ParameterPtr>()->name();
|
||||
auto para = weight_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
weight_name = para->name();
|
||||
}
|
||||
// find the fakequant from input
|
||||
int64_t count = 0;
|
||||
|
@ -501,9 +506,12 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
|
|||
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
|
||||
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
||||
} else {
|
||||
fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
|
||||
auto param = fakequant_min_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
fakequant_min_node_name = param->name();
|
||||
}
|
||||
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
|
||||
MS_EXCEPTION_IF_NULL(quant_op_value);
|
||||
if (!quant_op_value->isa<PrimitivePy>()) {
|
||||
return;
|
||||
}
|
||||
|
@ -602,6 +610,7 @@ bool IsPhaseTrain(const std::string &phase_s) {
|
|||
}
|
||||
|
||||
std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
bool is_air = IsPhaseExportAir(phase_s);
|
||||
|
||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
|
|
|
@ -69,6 +69,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
|
|||
|
||||
// Only auto_parallel and semi_auto_parallel support PipelineSplit
|
||||
bool PipelineSplit(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
|
||||
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
|
||||
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
|
||||
|
|
|
@ -54,6 +54,7 @@ void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
|
|||
}
|
||||
|
||||
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (analysis_finish_flg_) {
|
||||
return;
|
||||
}
|
||||
|
@ -106,7 +107,9 @@ void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &nod
|
|||
const std::shared_ptr<AnfNode> ¶m,
|
||||
ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
|
||||
MS_EXCEPTION_IF_NULL(arg_info);
|
||||
MS_EXCEPTION_IF_NULL(user_node);
|
||||
auto cnode = user_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const size_t tuple_get_item_size = 3;
|
||||
const size_t index = 2;
|
||||
if (cnode->size() != tuple_get_item_size) {
|
||||
|
@ -140,6 +143,7 @@ void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &nod
|
|||
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> ¶m,
|
||||
ParamUsingInfo *arg_info) const {
|
||||
MS_EXCEPTION_IF_NULL(arg_info);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
auto abs = param->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
|
@ -201,6 +205,7 @@ FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bpro
|
|||
|
||||
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
|
||||
const ValuePtr &out, const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||
abstract::AbstractBasePtrList abs_list;
|
||||
ArgsToAbs(prim, op_args, &abs_list);
|
||||
|
||||
|
@ -233,6 +238,7 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
|
|||
|
||||
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
|
||||
const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
|
||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||
abstract::AbstractBasePtrList abs_list;
|
||||
ArgsToAbs(prim, op_args, &abs_list);
|
||||
if (!hook_flg) {
|
||||
|
@ -272,6 +278,7 @@ PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPt
|
|||
|
||||
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
|
||||
const abstract::AbstractBasePtrList &abs_list_input) {
|
||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||
auto ¶ms = bprop_fg->parameters();
|
||||
if (abs_list_input.size() != params.size()) {
|
||||
MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size();
|
||||
|
@ -306,6 +313,9 @@ ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
|
|||
const abstract::AbstractBasePtrList &abs_list,
|
||||
PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
|
||||
PrimBpropOptGraphInfoPtr *level_1_graph_info) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_EXCEPTION_IF_NULL(level_1_graph_info);
|
||||
MS_EXCEPTION_IF_NULL(level_2_graph_info);
|
||||
auto attrs_ = prim->attrs();
|
||||
for (auto &item : attrs_) {
|
||||
MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
|
||||
|
@ -327,6 +337,8 @@ ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
|
|||
|
||||
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
|
||||
abstract::AbstractBasePtrList *abs_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_EXCEPTION_IF_NULL(abs_list);
|
||||
auto const_input_index = prim->get_const_input_indexes();
|
||||
bool have_const_input = !const_input_index.empty();
|
||||
bool is_const_prim = prim->is_const_prim();
|
||||
|
@ -345,6 +357,7 @@ void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList
|
|||
|
||||
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
|
||||
const abstract::AbstractBasePtrList &abs_list) {
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments.";
|
||||
}
|
||||
|
|
|
@ -29,6 +29,8 @@ namespace mindspore {
|
|||
namespace pipeline {
|
||||
void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache,
|
||||
HashValue *const hash_value) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(hash_cache);
|
||||
const auto &to_check_value = GetValueNode(node);
|
||||
MS_EXCEPTION_IF_NULL(to_check_value);
|
||||
|
||||
|
|
|
@ -72,8 +72,10 @@ void AnalysisSchedule::SetNextRunnableImpl() {
|
|||
return;
|
||||
}
|
||||
// Check if enter endless loop
|
||||
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(),
|
||||
[](const auto &item) { return item->HasResult(); });
|
||||
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(), [](const auto &item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return item->HasResult();
|
||||
});
|
||||
if (it == asyncAbstractList_.end()) {
|
||||
// Enter endless loop if there is not ready result.
|
||||
MS_LOG(EXCEPTION) << "Enter endless loop. There isn't any branch that can been evaluated. Please check the code.";
|
||||
|
@ -184,6 +186,7 @@ void AnalysisResultCacheMgr::Todo() {
|
|||
std::lock_guard<std::mutex> lock(todo_lock_);
|
||||
while (!todo_.empty()) {
|
||||
AnfNodeConfigPtr conf = todo_.front();
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
todo_.pop_front();
|
||||
if (GetValue(conf) == nullptr) {
|
||||
MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
|
||||
|
@ -193,11 +196,14 @@ void AnalysisResultCacheMgr::Todo() {
|
|||
MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
|
||||
continue;
|
||||
}
|
||||
if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {
|
||||
auto switch_value = TryGetSwitchValue(conf);
|
||||
auto abstract = GetValue(conf)->abstract();
|
||||
MS_EXCEPTION_IF_NULL(switch_value);
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (!(*abstract == *switch_value)) {
|
||||
MS_LOG(WARNING) << " Switch Value is not eq. "
|
||||
<< " switch cache: " << TryGetSwitchValue(conf)->ToString()
|
||||
<< " globle cache: " << GetValue(conf)->abstract()->ToString()
|
||||
<< "\tconf: " << conf->node()->ToString();
|
||||
<< " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
|
||||
<< "\t\tConf: " << conf->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
|
|||
// Add or get a monad parameter.
|
||||
AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
|
||||
const abstract::AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
size_t params_size = func_graph->parameters().size();
|
||||
size_t io_monad_location = params_size;
|
||||
// Search for existed parameters, return it if found.
|
||||
|
@ -109,6 +110,7 @@ bool HasAbstractRef(const AnfNodePtr &node) {
|
|||
// Gets ref inputs and its indexes from a cnode.
|
||||
RefInputs GetRefInputs(const CNodePtr &cnode) {
|
||||
RefInputs ref_inputs;
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->inputs().at(i);
|
||||
if (HasAbstractRef(input)) {
|
||||
|
@ -231,6 +233,7 @@ class SccFinder {
|
|||
// Search SCCs from the given graph.
|
||||
const State &Search(FuncGraphPtr graph) {
|
||||
// Create graph state, set it as visited.
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto [inserted, ok] = visited_.emplace(graph, State(index_++));
|
||||
if (!ok) {
|
||||
MS_LOG(EXCEPTION) << "Already visited: " << graph->ToString();
|
||||
|
@ -336,6 +339,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
static void UpdateOrderList(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
OrderedSet<CNodePtr> new_order_list;
|
||||
const auto &order_list = func_graph->order_list();
|
||||
for (auto &cnode : order_list) {
|
||||
|
@ -368,11 +372,13 @@ class SideEffectFinder {
|
|||
|
||||
// Gets branch graph from a switch cnode at given input index.
|
||||
FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
|
||||
}
|
||||
|
||||
// Gets branch graphs from a switch cnode.
|
||||
std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t switch_cnode_size = 4;
|
||||
constexpr size_t true_index = 2;
|
||||
constexpr size_t false_index = 3;
|
||||
|
@ -457,6 +463,7 @@ class SideEffectFinder {
|
|||
|
||||
// Gets branch graphs from a switch_layer cnode.
|
||||
std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t func_tuple_index = 2;
|
||||
if (cnode->size() <= func_tuple_index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
|
||||
|
@ -485,11 +492,12 @@ class SideEffectFinder {
|
|||
if (func_graph != nullptr) {
|
||||
return GetGraphsFromTuple(func_graph->output());
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid input for switch_layer: " << func_tuple->DebugString(2);
|
||||
MS_LOG(EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
|
||||
}
|
||||
|
||||
// Get graphs from a tuple of funcs make node for switch_layer.
|
||||
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto &inputs = make_tuple->inputs();
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
|
||||
|
@ -531,6 +539,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
EffectInfo TraceTupleEffectInfo(const AnfNodePtr &tuple_node, std::stack<int64_t> *tuple_indexes) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_indexes);
|
||||
auto para = dyn_cast<Parameter>(tuple_node);
|
||||
if (para != nullptr) {
|
||||
return TraceTupleParaEffectInfo(para, *tuple_indexes);
|
||||
|
@ -540,7 +549,7 @@ class SideEffectFinder {
|
|||
return TraceTupleCNodeEffectInfo(tuple_cnode, tuple_indexes);
|
||||
}
|
||||
// Should not reach here.
|
||||
MS_LOG(EXCEPTION) << "Side effects untraceable: " << tuple_node->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Side effects untraceable: tuple_cnode is nullptr.";
|
||||
}
|
||||
|
||||
EffectInfo TraceTupleParaEffectInfo(const ParameterPtr ¶, const std::stack<int64_t> &tuple_indexes) {
|
||||
|
@ -556,6 +565,7 @@ class SideEffectFinder {
|
|||
|
||||
EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_indexes);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = GetPrimitive(cnode);
|
||||
// Trace MakeTuple.
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
|
||||
|
@ -640,6 +650,7 @@ class SideEffectFinder {
|
|||
}
|
||||
// Set merged effect info to both branches.
|
||||
for (auto &branch : branches) {
|
||||
MS_EXCEPTION_IF_NULL(branch);
|
||||
branch->SetEffectInfo(info);
|
||||
// Update caller if it is existed.
|
||||
UpdateBranchCaller(branch);
|
||||
|
@ -650,6 +661,7 @@ class SideEffectFinder {
|
|||
EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
|
||||
EffectInfo info = {EffectInfo::kDetected, false, false, false};
|
||||
for (auto &branch : branches) {
|
||||
MS_EXCEPTION_IF_NULL(branch);
|
||||
EffectInfo branch_info = GetEffectInfo(branch);
|
||||
info.Merge(branch_info);
|
||||
}
|
||||
|
@ -658,6 +670,7 @@ class SideEffectFinder {
|
|||
|
||||
// Trace a cnode for effect info.
|
||||
EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = GetPrimitive(cnode);
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
|
||||
// Special handling for Switch primitive.
|
||||
|
@ -740,7 +753,7 @@ class SideEffectFinder {
|
|||
}
|
||||
}
|
||||
// Something is wrong if we reached here.
|
||||
MS_LOG(WARNING) << "EffectInfo untraceable: " << node->DebugString(2);
|
||||
MS_LOG(WARNING) << "EffectInfo untraceable: node is a nullptr.";
|
||||
return {EffectInfo::kDetected, false, false, false};
|
||||
}
|
||||
|
||||
|
@ -767,6 +780,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
void ForEachRealArguments(const ParameterPtr ¶, const std::function<void(const AnfNodePtr &)> &handler) {
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
auto func_graph = para->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// Find index of the parameter, starts from 0.
|
||||
|
@ -785,6 +799,7 @@ class SideEffectFinder {
|
|||
}
|
||||
// Caller cnode.
|
||||
auto cnode = dyn_cast<CNode>(user.first->first);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode && input_index < cnode->size()) {
|
||||
handler(cnode->input(input_index));
|
||||
}
|
||||
|
@ -793,6 +808,7 @@ class SideEffectFinder {
|
|||
|
||||
// For call node, returns effect info of the callee graph.
|
||||
EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t min_call_node_size = 2;
|
||||
if (cnode->size() < min_call_node_size) {
|
||||
MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
|
||||
|
@ -886,20 +902,23 @@ class SideEffectFinder {
|
|||
const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
|
||||
auto found = scc_map_.find(func_graph);
|
||||
if (found == scc_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString();
|
||||
MS_LOG(EXCEPTION) << "SCC not found for " << (func_graph ? func_graph->ToString() : "FG(null)");
|
||||
}
|
||||
return found->second;
|
||||
}
|
||||
|
||||
// Set effect info for all member graphs in the SCC.
|
||||
void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
|
||||
MS_EXCEPTION_IF_NULL(scc);
|
||||
for (auto &g : *scc) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
g->SetEffectInfo(info);
|
||||
}
|
||||
}
|
||||
|
||||
// Gets EffectInfo for func graph.
|
||||
EffectInfo GetEffectInfo(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &effect_info = func_graph->GetEffectInfo();
|
||||
if (effect_info.state != EffectInfo::kUnknown) {
|
||||
// Effect info already set, return it.
|
||||
|
@ -907,6 +926,7 @@ class SideEffectFinder {
|
|||
}
|
||||
// Get SCC that this graph belongs to.
|
||||
auto &scc = GetScc(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(scc);
|
||||
// To prevent SCC members be visited again, we set effect info
|
||||
// to 'kDetecting' state before start to check cnodes.
|
||||
EffectInfo info{EffectInfo::kDetecting, false, false, false};
|
||||
|
@ -914,6 +934,7 @@ class SideEffectFinder {
|
|||
// Check side effects for all cnodes in the SCC.
|
||||
std::vector<CNodePtr> undetected;
|
||||
for (auto &g : *scc) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
for (auto &cnode : g->order_list()) {
|
||||
auto cnode_effect = GetEffectInfo(cnode);
|
||||
if (cnode_effect.state != EffectInfo::kDetected) {
|
||||
|
@ -935,6 +956,7 @@ class SideEffectFinder {
|
|||
SetSccEffectInfo(scc, info);
|
||||
// Check undetected cnodes again after side effect of the SCC is detected.
|
||||
for (auto &cnode : undetected) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto cnode_effect = GetEffectInfo(cnode);
|
||||
// Side effect should be detected now.
|
||||
if (cnode_effect.state != EffectInfo::kDetected) {
|
||||
|
@ -951,6 +973,8 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphPtr &branch) {
|
||||
MS_EXCEPTION_IF_NULL(branch);
|
||||
MS_EXCEPTION_IF_NULL(switch_node);
|
||||
auto manager = branch->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto &node_users = manager->node_users();
|
||||
|
@ -971,6 +995,7 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
void UpdateBranchCaller(const FuncGraphPtr &branch) {
|
||||
MS_EXCEPTION_IF_NULL(branch);
|
||||
auto iter = branch_caller_map.find(branch);
|
||||
if (iter == branch_caller_map.end()) {
|
||||
return;
|
||||
|
@ -992,6 +1017,8 @@ class SideEffectFinder {
|
|||
}
|
||||
|
||||
void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(monad);
|
||||
auto monad_abs = monad->ToAbstract();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto abs = cnode->inputs().at(i)->abstract();
|
||||
|
@ -1077,6 +1104,7 @@ class AutoMonadConverter {
|
|||
|
||||
// Gets effect info for a cnode.
|
||||
const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &effect_info = cnode->GetEffectInfo();
|
||||
if (effect_info.state != EffectInfo::kDetected) {
|
||||
// Effect info should have been set by SideEffectFinder.
|
||||
|
@ -1135,6 +1163,7 @@ class AutoMonadConverter {
|
|||
}
|
||||
|
||||
void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (info.memory || info.load) {
|
||||
(void)GetUniverse();
|
||||
bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
|
||||
|
@ -1187,6 +1216,7 @@ class AutoMonadConverter {
|
|||
}
|
||||
|
||||
void HandleLoad(const CNodePtr &cnode, bool update_state) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto value = GetValueNode(cnode->input(0));
|
||||
if (value && value->isa<Primitive>()) {
|
||||
// For primitive calls that use Ref as input, insert Loads before them.
|
||||
|
@ -1273,6 +1303,7 @@ class AutoMonadConverter {
|
|||
|
||||
// Add or replace monad input.
|
||||
void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t max_monad_inputs = 2;
|
||||
auto monad_abs = monad->abstract();
|
||||
auto &inputs = cnode->inputs();
|
||||
|
@ -1332,6 +1363,7 @@ class AutoMonadConverter {
|
|||
}
|
||||
|
||||
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
|
||||
MS_EXCEPTION_IF_NULL(attach);
|
||||
// Not attach UpdateState if set kAttrIgnoreSideEffect.
|
||||
auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
|
||||
auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
|
||||
|
@ -1453,6 +1485,7 @@ bool AutoMonad(const FuncGraphPtr &func_graph) {
|
|||
bool ReAutoMonad(const FuncGraphPtr &func_graph) {
|
||||
// AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
|
||||
// Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool need_auto_monad = false;
|
||||
std::vector<FuncGraphPtr> auto_monaded_fg;
|
||||
func_graph->EraseUnusedNodeInOrder();
|
||||
|
|
|
@ -38,7 +38,8 @@ string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList
|
|||
ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
|
||||
}
|
||||
for (size_t i = 0; i < arg_spec_list.size(); i++) {
|
||||
ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString();
|
||||
ss << evaluator->ToString() << " input[" << i
|
||||
<< "] abstract value: " << (arg_spec_list[i] ? arg_spec_list[i]->ToString() : "null abstract.");
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
@ -60,6 +61,9 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
|
|||
|
||||
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame) {
|
||||
MS_EXCEPTION_IF_NULL(current_stack_frame);
|
||||
MS_EXCEPTION_IF_NULL(new_stack_frame);
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
// Enter new func graph.
|
||||
auto ¤t_node = current_stack_frame->CurrentNode();
|
||||
auto current_context = current_stack_frame->current_context();
|
||||
|
@ -83,8 +87,8 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
|
|||
<< "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
|
||||
}
|
||||
|
||||
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
|
||||
const StackFramePtr ¤t_stack_frame) {
|
||||
void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame) {
|
||||
MS_EXCEPTION_IF_NULL(current_stack_frame);
|
||||
// Leave current func graph.
|
||||
auto current_context = current_stack_frame->current_context();
|
||||
trace::TraceGraphEvalLeave(current_context);
|
||||
|
@ -149,8 +153,11 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
|
||||
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
|
||||
const AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
const AnfNodePtr &func_node = fg->get_return();
|
||||
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
@ -162,7 +169,9 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
|
|||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString();
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
MS_EXCEPTION_IF_NULL(node_eval_result);
|
||||
res_base = node_eval_result->abstract();
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
|
@ -254,6 +263,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
}
|
||||
|
||||
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
|
||||
MS_EXCEPTION_IF_NULL(broaded_args);
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
|
@ -467,6 +477,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|||
|
||||
// Call the original evaluator, get the result: y = f(x)
|
||||
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
|
||||
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
|
||||
AbstractBasePtrList bparams;
|
||||
|
|
|
@ -228,7 +228,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
const AnalysisContextPtr &context);
|
||||
static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame);
|
||||
static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame);
|
||||
static void LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame);
|
||||
};
|
||||
|
||||
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
|
||||
|
|
|
@ -199,6 +199,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Find refs from the cnode inputs.
|
||||
auto &inputs = cnode->inputs();
|
||||
const size_t last_index = inputs.size() - 1;
|
||||
|
@ -232,6 +233,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
|
||||
MS_EXCEPTION_IF_NULL(update_state);
|
||||
const size_t attach_index = 2;
|
||||
const size_t input_size = update_state->inputs().size();
|
||||
for (size_t index = attach_index; index < input_size; index++) {
|
||||
|
@ -368,6 +370,7 @@ class OrderEnforcer {
|
|||
|
||||
// Enforce order of execution for Load users node.
|
||||
void OrderEnforce(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
OrderEnforcer enforcer(func_graph);
|
||||
enforcer.Run();
|
||||
auto fg_used_total = func_graph->func_graphs_used_total();
|
||||
|
|
|
@ -52,10 +52,17 @@ std::unordered_set<std::string> prims_to_skip_undetermined_infer{
|
|||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(ref);
|
||||
MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
|
||||
return ref->ObtainEvalResult()->abstract();
|
||||
});
|
||||
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(do_signature);
|
||||
auto &func = do_signature->function();
|
||||
if (func->isa<Primitive>()) {
|
||||
auto sig_prim = func->cast<PrimitivePtr>();
|
||||
|
@ -67,17 +74,16 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
||||
auto out_node = dyn_cast<CNode>(out_conf->node());
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
const auto &out_node_inputs = out_node->inputs();
|
||||
if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString()
|
||||
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
|
||||
<< ", inputs size " << out_node_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
|
||||
<< args_conf_list.size() << ", inputs size " << out_node_inputs.size();
|
||||
}
|
||||
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||
|
||||
|
@ -90,11 +96,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
AnfNodePtr new_node = nullptr;
|
||||
if (bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
|
||||
} else {
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
|
||||
}
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
|
||||
|
@ -137,12 +141,17 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
|
|||
|
||||
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
||||
auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(unpack_graph);
|
||||
auto out_node = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
const auto &out_node_inputs = out_node->inputs();
|
||||
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
|
||||
|
@ -152,8 +161,15 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(ref);
|
||||
MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
|
||||
return ref->ObtainEvalResult()->abstract();
|
||||
});
|
||||
// get the forward graph
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
|
||||
if (fn == nullptr) {
|
||||
|
@ -165,7 +181,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
AbstractBasePtrList graph_specialize_args =
|
||||
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
|
||||
|
||||
AbstractBasePtrList graph_specialize_args_without_sens;
|
||||
if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
|
||||
MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
|
||||
|
@ -188,6 +203,8 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
|
||||
AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
|
||||
const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node_type);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
AnfNodePtr target_node = source_node;
|
||||
if (node_type->isa<AbstractTensor>()) {
|
||||
auto x = node_type->cast<AbstractTensorPtr>();
|
||||
|
@ -239,12 +256,14 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
|
|||
|
||||
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
auto out_node = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
const auto &out_node_inputs = out_node->inputs();
|
||||
if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
|
||||
MS_LOG(EXCEPTION) << "MixedPrecisionCast"
|
||||
|
@ -258,8 +277,12 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
|
|||
scope = out_conf->node()->scope();
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
||||
FuncGraphPtr func_graph = out_conf->node()->func_graph();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
constexpr size_t source_node_index = 2;
|
||||
if (out_node_inputs.size() <= source_node_index) {
|
||||
MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
|
||||
}
|
||||
|
||||
AnfNodePtr new_node =
|
||||
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
|
@ -282,6 +305,7 @@ py::object BuildValue(const ValuePtr &value_ptr) {
|
|||
|
||||
py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
||||
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
||||
MS_EXCEPTION_IF_NULL(arg_tuple);
|
||||
size_t len = arg_tuple->size();
|
||||
py::tuple shape_tuple(len);
|
||||
py::tuple dtype_tuple(len);
|
||||
|
@ -317,6 +341,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
|||
auto dic = py::dict();
|
||||
dic[ATTR_SHAPE] = shape_tuple;
|
||||
dic[ATTR_DTYPE] = dtype_tuple;
|
||||
MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
|
||||
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
|
@ -337,6 +362,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
|||
|
||||
py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
||||
auto arg_list = dyn_cast<AbstractList>(abs_base);
|
||||
MS_EXCEPTION_IF_NULL(arg_list);
|
||||
size_t len = arg_list->size();
|
||||
py::list shape_list(len);
|
||||
py::list dtype_list(len);
|
||||
|
@ -361,6 +387,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
|||
auto dic = py::dict();
|
||||
dic[ATTR_SHAPE] = shape_list;
|
||||
dic[ATTR_DTYPE] = dtype_list;
|
||||
MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
|
||||
if (arg_list->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
|
@ -377,6 +404,9 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
|||
|
||||
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
|
||||
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
||||
MS_EXCEPTION_IF_NULL(dic);
|
||||
MS_EXCEPTION_IF_NULL(arg_tensor);
|
||||
MS_EXCEPTION_IF_NULL(arg_tensor->shape());
|
||||
(*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
const auto &min_shape = arg_tensor->shape()->min_shape();
|
||||
|
@ -399,12 +429,15 @@ void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *di
|
|||
}
|
||||
|
||||
void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
|
||||
MS_EXCEPTION_IF_NULL(dic);
|
||||
MS_EXCEPTION_IF_NULL(abs_base);
|
||||
(*dic)[ATTR_SHAPE] = py::none();
|
||||
(*dic)[ATTR_DTYPE] = abs_base->BuildType();
|
||||
(*dic)[ATTR_VALUE] = py::none();
|
||||
if (abs_base->isa<PartialAbstractClosure>()) {
|
||||
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
|
||||
if (!args.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
|
||||
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
|
||||
if (value != nullptr) {
|
||||
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
|
||||
|
@ -466,6 +499,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
auto value = abs_base->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if ((*value == *kAnyValue)) {
|
||||
auto value_desc = abs_base->value_desc();
|
||||
MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
|
||||
|
@ -490,6 +524,8 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
|
|||
}
|
||||
|
||||
void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_EXCEPTION_IF_NULL(res_spec);
|
||||
const string kOutputNum = "output_num";
|
||||
if (prim->IsCustomPrim()) {
|
||||
// Raise error if output_num is not match the infer result.
|
||||
|
@ -609,8 +645,10 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value =
|
||||
(!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
|
||||
std::all_of(args.begin(), args.end(),
|
||||
[](const AbstractBasePtr &abs) -> bool { return (abs->BuildValue() != nullptr); });
|
||||
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
return (abs->BuildValue() != nullptr);
|
||||
});
|
||||
AbstractBasePtr abs_base = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
prim_->BeginRecordAddAttr();
|
||||
|
@ -684,8 +722,16 @@ EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Ab
|
|||
TypePtrList selections;
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
(void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
|
||||
[&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
|
||||
[&args](size_t arg_idx) -> TypePtr {
|
||||
if (arg_idx >= args.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(args[arg_idx]);
|
||||
return args[arg_idx]->GetTypeTrack();
|
||||
});
|
||||
TypePtr res = CheckTypeList(item.first, selections);
|
||||
MS_EXCEPTION_IF_NULL(return_value_type_);
|
||||
MS_EXCEPTION_IF_NULL(item.first);
|
||||
if (*return_value_type_ == *(item.first)) {
|
||||
ret_value_type = res;
|
||||
}
|
||||
|
@ -806,6 +852,7 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
input.push_back(conf->node());
|
||||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
FuncGraphPtr func_graph = old_conf->node()->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
CNodePtr new_cnode = func_graph->NewCNode(input);
|
||||
if (require_type == REQUIRE_TYPE::ATTR) {
|
||||
new_cnode = func_graph->NewCNode({new_cnode});
|
||||
|
@ -829,11 +876,13 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
|
|||
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
|
||||
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
|
||||
auto data_v = args_spec_list[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(data_v);
|
||||
if (!data_v->isa<parse::NameSpace>()) {
|
||||
MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
|
||||
}
|
||||
|
||||
auto item_v = args_spec_list[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_v);
|
||||
if (item_v->isa<StringImm>()) {
|
||||
item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
|
||||
}
|
||||
|
@ -845,9 +894,10 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
|
|||
// item_name to func addr from obj_map
|
||||
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
|
||||
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
auto out_node = out_conf->node();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
|
||||
if (new_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Resolve node failed";
|
||||
|
@ -857,6 +907,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
|
|||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
@ -877,7 +928,7 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
|||
std::string item_name = item_v->cast<StringImmPtr>()->value();
|
||||
MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
|
||||
MS_LOG(DEBUG) << "Resolve item: " << item_name;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
AbstractBasePtr attr = cls->GetAttribute(item_name);
|
||||
if (attr != nullptr) {
|
||||
return std::make_shared<EvalResult>(attr, nullptr);
|
||||
|
@ -885,6 +936,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
|||
|
||||
ValuePtr method = cls->GetMethod(item_name);
|
||||
if (method->isa<AnyValue>()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
|
||||
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
|
||||
<< ", item value: " << item_v->ToString();
|
||||
}
|
||||
|
@ -920,6 +973,7 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
if (require.is<std::string>()) {
|
||||
// composite registered in standard_method_map go to this branch
|
||||
converted_v = prim::GetPythonOps(require.cast<std::string>());
|
||||
MS_EXCEPTION_IF_NULL(converted_v);
|
||||
if (!converted_v->isa<Primitive>()) {
|
||||
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
|
||||
}
|
||||
|
@ -945,6 +999,7 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
scope = out_conf->node()->scope();
|
||||
}
|
||||
ScopeGuard scope_guard(scope);
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (item_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
@ -973,7 +1028,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
}
|
||||
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(node_conf);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
|
||||
AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
|
||||
x = SensitivityTransform(x);
|
||||
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
|
||||
|
@ -983,12 +1038,12 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
};
|
||||
|
||||
static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto root_g_set = manager->roots();
|
||||
if (root_g_set.size() != 1) {
|
||||
return nullptr;
|
||||
}
|
||||
const FuncGraphPtr &root_g = root_g_set.back();
|
||||
|
||||
for (auto ¶m_node : root_g->parameters()) {
|
||||
auto param = param_node->cast<ParameterPtr>();
|
||||
if (param && name == param->name()) {
|
||||
|
@ -1014,8 +1069,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
|
||||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
|
||||
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||
if (ref_abs == nullptr) {
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
|
@ -1126,6 +1182,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
// Get the type parameter.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
|
||||
<< type->ToString();
|
||||
|
@ -1179,6 +1236,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
// Exclude class type by minus 1;
|
||||
std::size_t params_size = args_spec_list.size() - 1;
|
||||
auto params = py::tuple(params_size);
|
||||
if (params_size > params.size()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected params_size:" << params_size << ",params.size():" << params.size();
|
||||
}
|
||||
if (params_size > 0) {
|
||||
for (size_t i = 0; i < params_size; i++) {
|
||||
// Only support the Scalar parameters type. Bypass class type by offset with 1.
|
||||
|
@ -1209,9 +1269,11 @@ class PartialEvaluator : public Evaluator {
|
|||
MS_EXCEPTION_IF_NULL(args_conf_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
|
||||
auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
|
||||
MS_EXCEPTION_IF_NULL(arg0_value);
|
||||
AbstractBasePtrList args_spec_list{arg0_value};
|
||||
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
|
||||
if (arg0_value->isa<AbstractError>()) {
|
||||
MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
|
||||
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
|
||||
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
||||
<< " as func is: " << arg0_value->ToString();
|
||||
|
@ -1259,7 +1321,8 @@ class PartialEvaluator : public Evaluator {
|
|||
}
|
||||
|
||||
EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
|
||||
const AnfNodeConfigPtr &out_conf = nullptr) const {
|
||||
const AnfNodeConfigPtr &out_conf) const {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
auto cnode = out_conf->node()->cast<CNodePtr>();
|
||||
|
@ -1273,7 +1336,7 @@ class PartialEvaluator : public Evaluator {
|
|||
|
||||
ScopePtr scope = out_conf->node()->scope();
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
|
@ -1456,7 +1519,7 @@ bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
|
|||
if (model->IsGeneric()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(model_class);
|
||||
if (x_class->tag() == model_class->tag()) {
|
||||
auto m_attributes = model_class->GetAttributes();
|
||||
auto x_attributes = x_class->attributes();
|
||||
|
|
|
@ -32,13 +32,16 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
namespace {
|
||||
inline AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
if (conf->node()->intermediate_abstract()) {
|
||||
return conf->node()->intermediate_abstract();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(conf->ObtainEvalResult());
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
}
|
||||
|
||||
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
|
||||
MS_EXCEPTION_IF_NULL(abs_base);
|
||||
AnfNodePtr value_node = NewValueNode(v);
|
||||
value_node->set_abstract(abs_base);
|
||||
MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
|
||||
|
@ -53,6 +56,7 @@ bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
|
|||
}
|
||||
|
||||
bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
|
||||
MS_EXCEPTION_IF_NULL(abs_base);
|
||||
if (abs_base->isa<AbstractTensor>()) {
|
||||
return true;
|
||||
} else if (abs_base->isa<AbstractSequeue>()) {
|
||||
|
@ -69,7 +73,8 @@ bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
|
|||
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
|
||||
MS_LOG(DEBUG) << "Specialize topmost function graph: "
|
||||
<< (context->func_graph() ? context->func_graph()->ToString() : "FG(Null)");
|
||||
if (top_context_ == nullptr) {
|
||||
top_context_ = context;
|
||||
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
|
||||
|
@ -82,6 +87,7 @@ FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, con
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = specializations_.find(context->SpecializeKey());
|
||||
if (iter != specializations_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return iter->second->specialized_func_graph();
|
||||
}
|
||||
|
||||
|
@ -132,11 +138,11 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
|
|||
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
|
||||
|
||||
// If had replicated, just return that.
|
||||
MS_EXCEPTION_IF_NULL(specializer->repl_node_);
|
||||
auto iter = specializer->repl_node_->find(node);
|
||||
if (iter != specializer->repl_node_->end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
auto new_node = specializer->cloner_->CloneDisconnected(node);
|
||||
if (node->isa<CNode>()) {
|
||||
if (!new_node->isa<CNode>()) {
|
||||
|
@ -157,6 +163,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod
|
|||
}
|
||||
|
||||
void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto c_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(c_node);
|
||||
auto inputs = c_node->inputs();
|
||||
|
@ -170,6 +177,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
|
|||
MS_EXCEPTION_IF_NULL(c_inp);
|
||||
auto c_new_inp = new_inp->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(c_new_inp);
|
||||
MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
|
||||
MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString();
|
||||
c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
|
||||
}
|
||||
|
@ -208,8 +216,9 @@ std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(co
|
|||
MS_EXCEPTION_IF_NULL(specializer_->top_context());
|
||||
if (specializer_->top_context()->func_graph() == fg) { // `fg` is top func graph.
|
||||
specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
|
||||
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
|
||||
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
MS_LOG(INFO) << "Used top func graph specializer as parent for "
|
||||
<< (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
|
||||
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
MS_EXCEPTION_IF_NULL(specializer);
|
||||
break;
|
||||
}
|
||||
|
@ -219,21 +228,28 @@ std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(co
|
|||
if (specializer == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
|
||||
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
|
||||
<< func_graph_->ToString() << " has no parent context? At least not " << fg->ToString();
|
||||
<< (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||
<< " has no parent context? At least not " << fg->ToString();
|
||||
}
|
||||
}
|
||||
return specializer;
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::Run() {
|
||||
MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
|
||||
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
|
||||
<< ", func graph: " << func_graph_->get_return()->DebugString();
|
||||
MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||
<< ", cloned func graph name: "
|
||||
<< (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
|
||||
<< (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
|
||||
: "FG(null)");
|
||||
FirstPass();
|
||||
SecondPass();
|
||||
MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
|
||||
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
|
||||
<< ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
|
||||
MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
|
||||
<< ", cloned func graph name: "
|
||||
<< (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
|
||||
<< (specialized_func_graph_ ? specialized_func_graph_->get_return()
|
||||
? specialized_func_graph_->get_return()->DebugString()
|
||||
: "return null"
|
||||
: "FG(null)");
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::FirstPass() {
|
||||
|
@ -291,8 +307,8 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (new_node->func_graph() != specialized_func_graph_) {
|
||||
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
|
||||
<< ", new_node: " << new_node->DebugString()
|
||||
<< ", new_node->func_graph(): " << new_node->func_graph()->ToString()
|
||||
<< ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
|
||||
<< (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
|
||||
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
|
||||
return;
|
||||
}
|
||||
|
@ -310,6 +326,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
auto attrs = conf->ObtainEvalResult()->attribute();
|
||||
auto c_old = node->cast<CNodePtr>();
|
||||
auto c_new = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(c_new);
|
||||
auto new_inputs = c_new->inputs();
|
||||
auto old_inputs = c_old->inputs();
|
||||
for (size_t i = 0; i < old_inputs.size(); ++i) {
|
||||
|
@ -321,7 +338,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
|
||||
if (replace_node == nullptr) {
|
||||
replace_node = BuildReplacedNode(iconf);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_abstract(ival);
|
||||
MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
|
||||
} else {
|
||||
|
@ -343,7 +359,7 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
|
|||
auto conf_iter = engine_->anfnode_config_map().find(conf);
|
||||
AnfNodeConfigPtr new_conf = conf;
|
||||
while (conf_iter != engine_->anfnode_config_map().end()) {
|
||||
MS_LOG(DEBUG) << "Origin conf: node(" << new_conf->node()->DebugString() << ")";
|
||||
MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
|
||||
new_conf = conf_iter->second;
|
||||
MS_EXCEPTION_IF_NULL(new_conf);
|
||||
const auto &forward_node = new_conf->node();
|
||||
|
@ -366,15 +382,17 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
|
|||
// CloneOrderlist, and it will be replaced inside ReplicateDisconnectedNode.
|
||||
// For 2.1 the following code will do the job, replace replicated origin cnode with the replicated
|
||||
// forward one in the replicated func_graph.
|
||||
MS_EXCEPTION_IF_NULL(conf_iter->first);
|
||||
const auto &origin_node = conf_iter->first->node();
|
||||
const auto &replicated_origin_node = GetReplicatedNode(origin_node);
|
||||
if (replicated_origin_node != origin_node) {
|
||||
MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
|
||||
<< ", with replicated forwarded node: " << replicated_forward_node->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
|
||||
replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
|
||||
<< origin_node->DebugString();
|
||||
<< (origin_node ? origin_node->DebugString() : "Node(Null)");
|
||||
}
|
||||
}
|
||||
conf_iter = engine_->anfnode_config_map().find(new_conf);
|
||||
|
@ -406,6 +424,7 @@ inline bool CanSpecializeNode(const AnfNodePtr &node) {
|
|||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
|
||||
const AbstractBasePtrList &argvals) {
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
|
||||
MS_EXCEPTION_IF_NULL(real_a);
|
||||
|
||||
|
@ -429,9 +448,11 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co
|
|||
}
|
||||
|
||||
// Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
if (func->isa<MetaFuncGraphAbstractClosure>()) {
|
||||
auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
|
||||
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals[argvals.size() - 1]->isa<AbstractUMonad>()) {
|
||||
if (specialized_fg != nullptr && (argvals.size() > 1) && argvals.back() != nullptr &&
|
||||
argvals.back()->isa<AbstractUMonad>()) {
|
||||
specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
|
||||
}
|
||||
}
|
||||
|
@ -482,6 +503,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
|||
AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
|
||||
MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
|
||||
<< ", graph: " << context->func_graph()->get_return()->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(context->func_graph());
|
||||
if (context->func_graph()->stub()) {
|
||||
MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
|
||||
<< ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
|
||||
|
@ -489,6 +511,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &nod
|
|||
return node;
|
||||
}
|
||||
FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
v->set_flag(kFuncGraphFlagUndetermined, false);
|
||||
return BuildValueNode(v, abs);
|
||||
}
|
||||
|
@ -505,8 +528,13 @@ AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &en
|
|||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
auto new_inputs = new_node->inputs();
|
||||
if (new_inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "inputs can't be empty.";
|
||||
}
|
||||
AnfNodePtr func = new_inputs[0];
|
||||
MS_EXCEPTION_IF_NULL(new_inputs[0]);
|
||||
AbstractBasePtr fnval = new_inputs[0]->abstract();
|
||||
|
||||
AbstractBasePtrList args;
|
||||
|
@ -549,6 +577,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
|
|||
partial_node_list.push_back(old_node);
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_node->func_graph());
|
||||
wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
|
||||
wrapped_node->set_abstract(partial_closure);
|
||||
}
|
||||
|
@ -556,6 +585,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
|
|||
}
|
||||
|
||||
const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
auto cache_iter = evalcaches_.find(eval);
|
||||
if (cache_iter == evalcaches_.end()) {
|
||||
evalcaches_[eval] = eval->evaluator_cache_mgr();
|
||||
|
@ -571,7 +601,11 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
EvalResultPtr ret = nullptr;
|
||||
AbstractBasePtrList broaded_argvals;
|
||||
std::vector<AbstractBasePtrList> args_vector;
|
||||
auto &origin_eval_cache = evalcaches_[eval]->GetCache();
|
||||
auto eval_cache_iter = evalcaches_.find(eval);
|
||||
if (eval_cache_iter == evalcaches_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Evaluator:" << eval->ToString() << " not exist in cache.";
|
||||
}
|
||||
auto &origin_eval_cache = eval_cache_iter->second->GetCache();
|
||||
for (auto &argvals_map : origin_eval_cache) {
|
||||
auto argvals = argvals_map.first;
|
||||
args_vector.push_back(argvals);
|
||||
|
@ -667,6 +701,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
|
||||
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
|
||||
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
|
||||
MS_EXCEPTION_IF_NULL(func->func_graph());
|
||||
if (status == kSpecializeFindUniqueArgvalPoly ||
|
||||
(func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
|
||||
auto wrapped_node = BuildSpecializedParameterNode(new_node);
|
||||
|
@ -697,6 +732,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
|
||||
namespace {
|
||||
void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
|
||||
MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
|
||||
int64_t i = 0;
|
||||
const EvalResultCache &map = evaluator_cache_mgr->GetCache();
|
||||
|
@ -706,6 +742,7 @@ void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const A
|
|||
}
|
||||
|
||||
bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
|
||||
MS_LOG(DEBUG) << "High order primitive return POLY.";
|
||||
return true;
|
||||
|
@ -733,6 +770,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
|
||||
auto data = evaluator_cache_mgr->GetValue(argvals);
|
||||
if (data != nullptr) {
|
||||
*result = std::make_pair(argvals, data->abstract());
|
||||
|
@ -740,13 +778,16 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
}
|
||||
DumpEvaluatorCache(evaluator_cache_mgr, argvals);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
|
||||
const EvalResultCache &choices = GetEvalCache(eval)->GetCache();
|
||||
auto cache = GetEvalCache(eval);
|
||||
MS_EXCEPTION_IF_NULL(cache);
|
||||
const EvalResultCache &choices = cache->GetCache();
|
||||
if (choices.get(argvals) != nullptr) {
|
||||
*result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
|
||||
MS_EXCEPTION_IF_NULL(cache->GetValue(argvals));
|
||||
*result = std::make_pair(argvals, cache->GetValue(argvals)->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices.size() == 1) {
|
||||
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
|
||||
MS_EXCEPTION_IF_NULL(choices.begin()->second);
|
||||
*result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices.empty()) {
|
||||
|
@ -768,11 +809,14 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
}
|
||||
}
|
||||
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto &prim_attrs = prim->attrs();
|
||||
bool is_attr_same = true;
|
||||
for (auto &item : *attrs) {
|
||||
auto itr = prim_attrs.find(item.first);
|
||||
if (itr != prim_attrs.end()) {
|
||||
MS_EXCEPTION_IF_NULL(itr->second);
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
if (!(*(itr->second) == *(item.second))) {
|
||||
is_attr_same = false;
|
||||
break;
|
||||
|
|
|
@ -64,6 +64,7 @@ class RemoveMonad {
|
|||
}
|
||||
|
||||
void RemoveMonadFromRandomNodes(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
|
@ -79,6 +80,7 @@ class RemoveMonad {
|
|||
}
|
||||
|
||||
void RemoveRandomNodesFromMonadChain(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const size_t first_index = 1;
|
||||
|
|
|
@ -22,11 +22,15 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &engine, const EvaluatorPtr &evaluator,
|
||||
const CNodePtr current_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(current_cnode);
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
AbstractBasePtrList args_abs_list;
|
||||
auto &inputs = current_cnode->inputs();
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
auto config = engine->MakeConfig(inputs[i], current_context_, current_context_->func_graph());
|
||||
auto abs = config->ObtainEvalResult()->abstract();
|
||||
auto result = config->ObtainEvalResult();
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
auto abs = result->abstract();
|
||||
args_abs_list.push_back(abs);
|
||||
}
|
||||
args_abs_list = evaluator->NormalizeArgs(args_abs_list);
|
||||
|
@ -36,6 +40,8 @@ AbstractBasePtrList StackFrame::GenerateArgsAbsList(const AnalysisEnginePtr &eng
|
|||
|
||||
AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr &fg_evaluator,
|
||||
const AbstractFunctionPtr &graph_func) {
|
||||
MS_EXCEPTION_IF_NULL(graph_func);
|
||||
MS_EXCEPTION_IF_NULL(fg_evaluator);
|
||||
AnalysisContextPtr parent_context = nullptr;
|
||||
auto func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(graph_func);
|
||||
if (func_graph_abs != nullptr) { // Set parent context for FuncGraphAbstractClosure.
|
||||
|
@ -55,6 +61,8 @@ AnalysisContextPtr StackFrame::GetParentContext(const BaseFuncGraphEvaluatorPtr
|
|||
// Inner jump implementation.
|
||||
StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr current_cnode,
|
||||
const AbstractFunctionPtr &graph_func) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(current_cnode);
|
||||
// Get the evaluator for func graph.
|
||||
auto evaluator = engine->GetEvaluatorFor(graph_func);
|
||||
auto fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);
|
||||
|
@ -102,6 +110,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
|
||||
// Check if we need branch to another func graph.
|
||||
StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
auto ¤t_node = CurrentNode();
|
||||
if (!current_node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
|
@ -126,19 +135,23 @@ StackFramePtr StackFrame::Jump(const AnalysisEnginePtr &engine) {
|
|||
|
||||
// Run one step in current func graph.
|
||||
EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
auto ¤t_node = NextNode();
|
||||
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
|
||||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
|
||||
<< ") = " << node_eval_result->abstract()->ToString();
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
|
||||
<< (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");
|
||||
return node_eval_result;
|
||||
}
|
||||
|
||||
// Return back from child func graph.
|
||||
void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last_stack_frame,
|
||||
const EvalResultPtr &eval_result) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(last_stack_frame);
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
// Overwrite the result if func graph is stub.
|
||||
EvalResultPtr result = eval_result;
|
||||
if (last_stack_frame->func_graph()->stub()) {
|
||||
|
|
|
@ -80,6 +80,7 @@ size_t StackFrameDepth() { return stack_frame_depth; }
|
|||
size_t StackFrameMaxDepth() { return stack_frame_max_depth; }
|
||||
|
||||
bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
|
||||
MS_EXCEPTION_IF_NULL(arg_spec);
|
||||
if (dyn_cast<AbstractScalar>(arg_spec)) {
|
||||
auto v = arg_spec->GetValueTrack();
|
||||
if (v->isa<SymbolicKeyInstance>()) {
|
||||
|
@ -91,6 +92,7 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
|
|||
|
||||
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
|
||||
if (dyn_cast<AbstractScalar>(arg1) && dyn_cast<AbstractScalar>(arg2)) {
|
||||
MS_EXCEPTION_IF_NULL(arg1);
|
||||
return arg1->Join(arg2);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -121,6 +123,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
AnalysisSchedule::GetInstance().Reset();
|
||||
AnalysisResult result;
|
||||
try {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
ConfigPtrList args_conf_list;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||
|
@ -241,42 +244,6 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|||
return eval_result;
|
||||
}
|
||||
|
||||
void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
|
||||
auto &list = trace::GetCNodeDebugStack();
|
||||
if (list.empty()) {
|
||||
return;
|
||||
}
|
||||
auto &previous_stack = list.back();
|
||||
MS_EXCEPTION_IF_NULL(previous_stack->node());
|
||||
MS_EXCEPTION_IF_NULL(conf->node());
|
||||
auto previous_cnode_fg = previous_stack->node()->func_graph();
|
||||
auto current_cnode_fg = conf->node()->func_graph();
|
||||
if (previous_cnode_fg != current_cnode_fg) { // Normal.
|
||||
return;
|
||||
}
|
||||
if (forward_count_ != 0) { // Ignore Forward Config.
|
||||
return;
|
||||
}
|
||||
auto &graph_stack = trace::GetCurrenGraphEvalStack();
|
||||
if (graph_stack.empty()) {
|
||||
return;
|
||||
}
|
||||
auto top_context = graph_stack.back().first;
|
||||
auto top_context_fg = top_context->func_graph();
|
||||
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
|
||||
return;
|
||||
}
|
||||
MS_LOG(ERROR) << "Should not use call stack in the same function: " << top_context_fg->ToString() << ", for "
|
||||
<< conf->node()->DebugString(2);
|
||||
for (size_t i = 0; i < list.size(); ++i) {
|
||||
auto old_conf = list[i];
|
||||
MS_LOG(ERROR) << " #" << i << ": " << old_conf->node()->DebugString(2) << ", in "
|
||||
<< old_conf->context()->func_graph()->ToString();
|
||||
}
|
||||
DumpIR("use_stack_error.ir", conf->node()->func_graph());
|
||||
MS_LOG(EXCEPTION) << "To check above CNode stack and dumped use_stack_error.ir";
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
|
@ -344,6 +311,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
ConfigPtrList args_conf_list;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||
|
@ -580,6 +548,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
|
||||
MS_EXCEPTION_IF_NULL(orig_conf);
|
||||
MS_EXCEPTION_IF_NULL(new_conf);
|
||||
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
|
||||
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
|
||||
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
|
||||
|
@ -590,6 +560,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
|
|||
if (new_conf->node()->isa<CNode>()) {
|
||||
auto new_cnode = new_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
MS_EXCEPTION_IF_NULL(old_cnode->func_graph());
|
||||
if (old_cnode->func_graph() == new_cnode->func_graph()) {
|
||||
MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->ToString()
|
||||
<< ", as origin node should be in order list, origin_node: " << old_cnode->ToString();
|
||||
|
@ -622,6 +593,7 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
}
|
||||
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
static std::mutex fg_lock;
|
||||
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
|
@ -650,6 +622,8 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
|
|||
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
|
||||
const EvalTraceRevIter &it, bool *continue_flag) {
|
||||
MS_EXCEPTION_IF_NULL(continue_flag);
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
*continue_flag = false;
|
||||
// Find latest entry function to handle nested recursion.
|
||||
EvaluatorPtr latest_entry = eval;
|
||||
|
@ -769,6 +743,7 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
}
|
||||
|
||||
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -860,6 +835,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
std::vector<AsyncAbstractPtr> branchAsyncResults;
|
||||
|
||||
for (auto &evaluator : evaluators) {
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
SetUndeterminedFlag(evaluator, possible_parent_fg);
|
||||
AsyncAbstractPtr branchAsyncResult = std::make_shared<AsyncAbstract>();
|
||||
// Control the order to run.
|
||||
|
@ -935,7 +911,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
for (auto eval : evaluators) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
(void)SetUndeterminedFlag(eval, possible_parent_fg);
|
||||
|
||||
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
||||
|
@ -1006,6 +981,7 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con
|
|||
}
|
||||
|
||||
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
AnfNodePtr anf_node = nullptr;
|
||||
if (conf != nullptr) {
|
||||
anf_node = conf->node();
|
||||
|
|
|
@ -282,7 +282,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
|
||||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
|
||||
bool enable_recursive_eval() const { return enable_recursive_eval_; }
|
||||
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
|
||||
// Primitive must in whitelist
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (abstract::IsInWhiteList(prim)) {
|
||||
return;
|
||||
}
|
||||
|
@ -70,6 +71,7 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
bool CheckAbstractScalar(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
AbstractBasePtr ptrBase = node->abstract();
|
||||
if (ptrBase->isa<AbstractScalar>()) {
|
||||
TypePtr ptrType = ptrBase->GetTypeTrack();
|
||||
|
|
Loading…
Reference in New Issue