!1822 fix codex split big functions

Merge pull request !1822 from fary86/codex_big_functions
This commit is contained in:
mindspore-ci-bot 2020-06-03 20:12:41 +08:00 committed by Gitee
commit 20afadb4c0
7 changed files with 234 additions and 193 deletions

View File

@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx,
std::map<AnfNodePtr, int> *const apply_map);
private:
std::string GetNodeType(const AnfNodePtr &nd) override;
@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
}
auto abs = ret->abstract();
if (abs == nullptr) {
return nullptr;
return "Undefined";
}
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt
return ctx;
}
void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph,
int *idx, std::map<AnfNodePtr, int> *const apply_map) {
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map);
// non-return node
if (cnode != func_graph->get_return()) {
int apply_idx = (*idx)++;
(*apply_map)[cnode] = apply_idx;
std::string type_info = GetNodeType(cnode);
if (type_info == "Undefined") {
ofs << " %" << apply_idx << " = " << op_text << "(";
} else {
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
}
} else {
ofs << " " << op_text << "(";
}
for (size_t i = 1; i < inputs.size(); ++i) {
if (i != 1) {
ofs << ", ";
}
AnfNodePtr arg = inputs[i];
ofs << GetAnfNodeText(func_graph, arg, *apply_map);
}
ofs << ")";
// process function graph call
auto ctx = ProcessFuncGraphCall(cnode);
// output comment
OutputStatementComment(ofs, cnode);
if (ctx != nullptr) {
ofs << " @ctx.addr=" << ctx.get();
}
ofs << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(cnode->debug_info()) << "\n";
} else {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
}
}
void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect
}
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
// non-return node
if (node != func_graph->get_return()) {
int apply_idx = idx++;
apply_map[node] = apply_idx;
std::string type_info = GetNodeType(node);
if (type_info == "Undefined") {
ofs << " %" << apply_idx << " = " << op_text << "(";
} else {
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
}
} else {
ofs << " " << op_text << "(";
}
for (size_t i = 1; i < inputs.size(); ++i) {
if (i != 1) {
ofs << ", ";
}
AnfNodePtr arg = inputs[i];
ofs << GetAnfNodeText(func_graph, arg, apply_map);
}
ofs << ")";
// process function graph call
auto ctx = ProcessFuncGraphCall(cnode);
// output comment
OutputStatementComment(ofs, cnode);
if (ctx != nullptr) {
ofs << " @ctx.addr=" << ctx.get();
}
ofs << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(cnode->debug_info()) << "\n";
} else {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
}
OutputCNode(ofs, cnode, func_graph, &idx, &apply_map);
}
}

View File

@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num
return true;
}
void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
const size_t type_number) {
*max_type_id = type_id;
*max_type = type;
*max_type_number = type_number;
}
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
const std::set<size_t> &write_indexs) {
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
TypeId *arg_type = nullptr) {
if (arg_value->isa<abstract::AbstractRef>()) {
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
auto tensor_type = tensor->element()->BuildType();
MS_EXCEPTION_IF_NULL(tensor_type);
*arg_type_id = tensor_type->type_id();
if (arg_type != nullptr) {
*arg_type = kObjectTypeTensorType;
}
return true;
}
if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType();
MS_EXCEPTION_IF_NULL(scalar_type);
*arg_type_id = scalar_type->type_id();
if (arg_type != nullptr) {
*arg_type = kObjectTypeNumber;
}
return true;
}
return false;
}
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
const std::set<size_t> &write_indices) {
TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown;
size_t max_type_number = 0;
bool has_int8 = false;
for (const auto &index : indexs) {
for (const auto &index : indices) {
TypeId arg_type_id = kTypeUnknown;
TypeId arg_type = kTypeUnknown;
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) {
auto is_write = (write_indexs.find(index) != write_indexs.end());
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
auto tensor_type = tensor->element()->BuildType();
MS_EXCEPTION_IF_NULL(tensor_type);
arg_type_id = tensor_type->type_id();
arg_type = kObjectTypeTensorType;
} else if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType();
MS_EXCEPTION_IF_NULL(scalar_type);
arg_type_id = scalar_type->type_id();
arg_type = kObjectTypeNumber;
} else {
auto is_write = (write_indices.find(index) != write_indices.end());
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
continue;
}
auto it = type_map.find(arg_type_id);
@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
has_int8 = true;
}
if (max_type_id == kTypeUnknown) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
continue;
}
if (max_type == arg_type) {
if (it->second > max_type_number) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
} else {
if (arg_type == kObjectTypeTensorType) {
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
} else {
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
}
}
}
@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list,
const std::set<size_t> &write_indexs) {
const std::set<size_t> &write_indices) {
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexs.find(dtypes[i]);
if (it == type_indexs.end()) {
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
auto it = type_indices.find(dtypes[i]);
if (it == type_indices.end()) {
(void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else {
it->second.push_back(i);
}
}
std::map<SignatureEnumDType, TypeId> dst_type;
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
auto type = it->first;
auto indexs = it->second;
auto indices = it->second;
// If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
if (indexs.size() < 2) {
if (indices.size() < 2) {
continue;
}
bool has_tensor = false;
for (const auto &index : indexs) {
for (const auto &index : indices) {
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
continue;
}
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
}
return dst_type;
}
@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; });
@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
// Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) {
continue;
}
auto rw_it = write_indexs.find(i);
auto is_write = (rw_it != write_indexs.end());
auto rw_it = write_indices.find(i);
auto is_write = (rw_it != write_indices.end());
AbstractBasePtr arg_value = args_spec_list[i];
if (arg_value->isa<abstract::AbstractRef>()) {
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
TypeId arg_type_id = kTypeUnknown;
if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
auto tensor_type = tensor->element()->BuildType();
MS_EXCEPTION_IF_NULL(tensor_type);
arg_type_id = tensor_type->type_id();
} else if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType();
MS_EXCEPTION_IF_NULL(scalar_type);
arg_type_id = scalar_type->type_id();
}
AbstractBasePtr arg_value = args_spec_list[i];
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
auto it_map = type_map.find(arg_type_id);
if (it_map == type_map.end()) {
continue;
@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
}
std::vector<AnfNodePtr> op_inputs;
std::set<size_t> write_indexs;
std::set<size_t> write_indices;
op_inputs.push_back(NewValueNode(function));
// Assume, the write input of op is always the first input. We check if any write op,
// and add cast op on other inputs to keep the same type with assigned parameter.
@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
} else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
write_indexs.insert(i);
write_indices.insert(i);
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
// process default
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
return func_graph->NewCNode(op_inputs);
}
} // namespace

View File

@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
return bprop_graph;
}
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}
}
*data = func_graph;
return true;
}
bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
auto obj_type = data_converter::GetObjType(obj);
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
// Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct'
if (data_converter::IsCellInstance(obj)) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}
}
*data = func_graph;
} else {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
return ConvertCellObjToFuncGraph(obj, data);
}
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
return true;
}
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));

View File

@ -608,7 +608,7 @@ void Pipeline::Run() {
MS_LOG(INFO) << "End";
}
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) {

View File

@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run);
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);
} // namespace pipeline
} // namespace mindspore

View File

@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) {
return;
}
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
}
}
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
const EvalTraceRevIter &it, bool *continue_flag) {
*continue_flag = false;
// Find latest entry function to handle nested recursion.
EvaluatorPtr latest_entry = eval;
auto latest_entry_iter = eval_trace_.rbegin();
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
if (it_temp != evaluators.end()) {
latest_entry = *it_temp;
latest_entry_iter = r_it;
break;
}
latest_entry_iter = ++r_it;
}
if (latest_entry != eval) {
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
*continue_flag = true;
return latest_entry;
}
bool has_undetermined = false;
// Check whether sub loop has untraced undetermined evaluator.
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
undetermined_evals.insert(*r_it);
}
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
*continue_flag = true;
return latest_entry;
}
return latest_entry;
}
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
if (out_specs.size() == 0) {
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
}
if (out_specs.size() == 1) {
MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
}
auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
return conf->GetEvaluatedValue()->abstract();
});
for (auto eval : evaluators) {
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
if (fg_eval) {
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
}
}
SetUndeterminedFlag(eval);
auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
multi_poss_.clear();
}
} else if (it != eval_trace_.rbegin()) {
// Find latest entry function to handle nested recursion.
EvaluatorPtr latest_entry = eval;
auto latest_entry_iter = eval_trace_.rbegin();
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
if (it_temp != evaluators.end()) {
latest_entry = *it_temp;
latest_entry_iter = r_it;
break;
}
latest_entry_iter = ++r_it;
}
if (latest_entry != eval) {
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
continue;
}
bool has_undetermined = false;
// Check whether sub loop has untraced undetermined evaluator.
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
undetermined_evals.insert(*r_it);
}
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
bool continue_flag = false;
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
if (continue_flag) {
continue;
}
@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
}
}
}
if (out_specs.size() == 0) {
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
}
if (out_specs.size() == 1) {
MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
}
auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
return ProcessEvalResults(out_specs);
}
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {

View File

@ -172,6 +172,8 @@ struct AnalysisResult {
AnalysisContextPtr context;
};
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
bool *continue_flag);
EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs);
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;