!1822 fix codex split big functions

Merge pull request !1822 from fary86/codex_big_functions
This commit is contained in:
mindspore-ci-bot 2020-06-03 20:12:41 +08:00 committed by Gitee
commit 20afadb4c0
7 changed files with 234 additions and 193 deletions

View File

@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx,
std::map<AnfNodePtr, int> *const apply_map);
private: private:
std::string GetNodeType(const AnfNodePtr &nd) override; std::string GetNodeType(const AnfNodePtr &nd) override;
@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
} }
auto abs = ret->abstract(); auto abs = ret->abstract();
if (abs == nullptr) { if (abs == nullptr) {
return nullptr; return "Undefined";
} }
auto dtype = abs->BuildType(); auto dtype = abs->BuildType();
auto shape = abs->BuildShape(); auto shape = abs->BuildShape();
@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt
return ctx; return ctx;
} }
void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph,
int *idx, std::map<AnfNodePtr, int> *const apply_map) {
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map);
// non-return node
if (cnode != func_graph->get_return()) {
int apply_idx = (*idx)++;
(*apply_map)[cnode] = apply_idx;
std::string type_info = GetNodeType(cnode);
if (type_info == "Undefined") {
ofs << " %" << apply_idx << " = " << op_text << "(";
} else {
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
}
} else {
ofs << " " << op_text << "(";
}
for (size_t i = 1; i < inputs.size(); ++i) {
if (i != 1) {
ofs << ", ";
}
AnfNodePtr arg = inputs[i];
ofs << GetAnfNodeText(func_graph, arg, *apply_map);
}
ofs << ")";
// process function graph call
auto ctx = ProcessFuncGraphCall(cnode);
// output comment
OutputStatementComment(ofs, cnode);
if (ctx != nullptr) {
ofs << " @ctx.addr=" << ctx.get();
}
ofs << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(cnode->debug_info()) << "\n";
} else {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
}
}
void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
const FuncGraphPtr &func_graph) { const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect
} }
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs(); OutputCNode(ofs, cnode, func_graph, &idx, &apply_map);
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
// non-return node
if (node != func_graph->get_return()) {
int apply_idx = idx++;
apply_map[node] = apply_idx;
std::string type_info = GetNodeType(node);
if (type_info == "Undefined") {
ofs << " %" << apply_idx << " = " << op_text << "(";
} else {
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
}
} else {
ofs << " " << op_text << "(";
}
for (size_t i = 1; i < inputs.size(); ++i) {
if (i != 1) {
ofs << ", ";
}
AnfNodePtr arg = inputs[i];
ofs << GetAnfNodeText(func_graph, arg, apply_map);
}
ofs << ")";
// process function graph call
auto ctx = ProcessFuncGraphCall(cnode);
// output comment
OutputStatementComment(ofs, cnode);
if (ctx != nullptr) {
ofs << " @ctx.addr=" << ctx.get();
}
ofs << "\n";
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
<< label_manage::Label(cnode->debug_info()) << "\n";
} else {
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
}
} }
} }

View File

@ -76,25 +76,16 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num
return true; return true;
} }
void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
const size_t type_number) { const size_t type_number) {
*max_type_id = type_id; *max_type_id = type_id;
*max_type = type; *max_type = type;
*max_type_number = type_number; *max_type_number = type_number;
} }
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs, bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
const std::set<size_t> &write_indexs) { TypeId *arg_type = nullptr) {
TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown;
size_t max_type_number = 0;
bool has_int8 = false;
for (const auto &index : indexs) {
TypeId arg_type_id = kTypeUnknown;
TypeId arg_type = kTypeUnknown;
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
auto is_write = (write_indexs.find(index) != write_indexs.end());
if (is_write) { if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else { } else {
@ -105,15 +96,36 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
auto tensor_type = tensor->element()->BuildType(); auto tensor_type = tensor->element()->BuildType();
MS_EXCEPTION_IF_NULL(tensor_type); MS_EXCEPTION_IF_NULL(tensor_type);
arg_type_id = tensor_type->type_id(); *arg_type_id = tensor_type->type_id();
arg_type = kObjectTypeTensorType; if (arg_type != nullptr) {
} else if (arg_value->isa<abstract::AbstractScalar>()) { *arg_type = kObjectTypeTensorType;
}
return true;
}
if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType(); auto scalar_type = scalar->BuildType();
MS_EXCEPTION_IF_NULL(scalar_type); MS_EXCEPTION_IF_NULL(scalar_type);
arg_type_id = scalar_type->type_id(); *arg_type_id = scalar_type->type_id();
arg_type = kObjectTypeNumber; if (arg_type != nullptr) {
} else { *arg_type = kObjectTypeNumber;
}
return true;
}
return false;
}
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
const std::set<size_t> &write_indices) {
TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown;
size_t max_type_number = 0;
bool has_int8 = false;
for (const auto &index : indices) {
TypeId arg_type_id = kTypeUnknown;
TypeId arg_type = kTypeUnknown;
auto is_write = (write_indices.find(index) != write_indices.end());
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
continue; continue;
} }
auto it = type_map.find(arg_type_id); auto it = type_map.find(arg_type_id);
@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
has_int8 = true; has_int8 = true;
} }
if (max_type_id == kTypeUnknown) { if (max_type_id == kTypeUnknown) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
continue; continue;
} }
if (max_type == arg_type) { if (max_type == arg_type) {
if (it->second > max_type_number) { if (it->second > max_type_number) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
} }
} else { } else {
if (arg_type == kObjectTypeTensorType) { if (arg_type == kObjectTypeTensorType) {
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
} }
} else { } else {
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
} }
} }
} }
@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments. // Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list, const abstract::AbstractBasePtrList &args_spec_list,
const std::set<size_t> &write_indexs) { const std::set<size_t> &write_indices) {
// record index for signature.dtypes of the same type // record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexs.find(dtypes[i]); auto it = type_indices.find(dtypes[i]);
if (it == type_indexs.end()) { if (it == type_indices.end()) {
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); (void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else { } else {
it->second.push_back(i); it->second.push_back(i);
} }
} }
std::map<SignatureEnumDType, TypeId> dst_type; std::map<SignatureEnumDType, TypeId> dst_type;
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
auto type = it->first; auto type = it->first;
auto indexs = it->second; auto indices = it->second;
// If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
if (indexs.size() < 2) { if (indices.size() < 2) {
continue; continue;
} }
bool has_tensor = false; bool has_tensor = false;
for (const auto &index : indexs) { for (const auto &index : indices) {
AbstractBasePtr arg_value = args_spec_list[index]; AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(void)dst_type.insert(std::make_pair(type, kTypeUnknown)); (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
continue; continue;
} }
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs))); (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
} }
return dst_type; return dst_type;
} }
@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) { std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
std::vector<SignatureEnumDType> dtypes; std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; }); [](const Signature &sig) { return sig.dtype; });
@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return; return;
} }
// Stat the index of the arguments with the largest type in the same SignatureEnumDType. // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs); std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
// Identify which arg requires auto cast // Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {
auto it = dst_type.find(dtypes[i]); auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) { if (it == dst_type.end() || it->second == kTypeUnknown) {
continue; continue;
} }
auto rw_it = write_indexs.find(i); auto rw_it = write_indices.find(i);
auto is_write = (rw_it != write_indexs.end()); auto is_write = (rw_it != write_indices.end());
AbstractBasePtr arg_value = args_spec_list[i];
if (arg_value->isa<abstract::AbstractRef>()) {
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
}
TypeId arg_type_id = kTypeUnknown; TypeId arg_type_id = kTypeUnknown;
if (arg_value->isa<abstract::AbstractTensor>()) { AbstractBasePtr arg_value = args_spec_list[i];
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
auto tensor_type = tensor->element()->BuildType();
MS_EXCEPTION_IF_NULL(tensor_type);
arg_type_id = tensor_type->type_id();
} else if (arg_value->isa<abstract::AbstractScalar>()) {
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
auto scalar_type = scalar->BuildType();
MS_EXCEPTION_IF_NULL(scalar_type);
arg_type_id = scalar_type->type_id();
}
auto it_map = type_map.find(arg_type_id); auto it_map = type_map.find(arg_type_id);
if (it_map == type_map.end()) { if (it_map == type_map.end()) {
continue; continue;
@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
} }
} }
std::vector<AnfNodePtr> op_inputs; std::vector<AnfNodePtr> op_inputs;
std::set<size_t> write_indexs; std::set<size_t> write_indices;
op_inputs.push_back(NewValueNode(function)); op_inputs.push_back(NewValueNode(function));
// Assume, the write input of op is always the first input. We check if any write op, // Assume, the write input of op is always the first input. We check if any write op,
// and add cast op on other inputs to keep the same type with assigned parameter. // and add cast op on other inputs to keep the same type with assigned parameter.
@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
} else if (sig == SignatureEnumRW::kRWWrite) { } else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
write_indexs.insert(i); write_indices.insert(i);
} }
// If sig is SignatureEnumRW::kRWRef, not do anything. // If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
} }
// process default // process default
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs); DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
return func_graph->NewCNode(op_inputs); return func_graph->NewCNode(op_inputs);
} }
} // namespace } // namespace

View File

@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
return bprop_graph; return bprop_graph;
} }
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}
}
*data = func_graph;
return true;
}
bool ConvertOtherObj(py::object obj, ValuePtr *const data) { bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
auto obj_type = data_converter::GetObjType(obj); auto obj_type = data_converter::GetObjType(obj);
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
// Create the namespace for common class instance // Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct' // When the obj is Cell, default parse the 'construct'
if (data_converter::IsCellInstance(obj)) { if (data_converter::IsCellInstance(obj)) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj); return ConvertCellObjToFuncGraph(obj, data);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
} }
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}
}
*data = func_graph;
} else {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
}
return true; return true;
} }
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));

View File

@ -608,7 +608,7 @@ void Pipeline::Run() {
MS_LOG(INFO) << "End"; MS_LOG(INFO) << "End";
} }
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) { void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
std::size_t size = args.size(); std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) { for (std::size_t i = 0; i < size; i++) {

View File

@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, bool need_run); const std::vector<int64_t> &input_indexes, bool need_run);
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list); void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
} }
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) {
return;
}
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
}
}
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
const EvalTraceRevIter &it, bool *continue_flag) {
*continue_flag = false;
// Find latest entry function to handle nested recursion.
EvaluatorPtr latest_entry = eval;
auto latest_entry_iter = eval_trace_.rbegin();
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
if (it_temp != evaluators.end()) {
latest_entry = *it_temp;
latest_entry_iter = r_it;
break;
}
latest_entry_iter = ++r_it;
}
if (latest_entry != eval) {
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
*continue_flag = true;
return latest_entry;
}
bool has_undetermined = false;
// Check whether sub loop has untraced undetermined evaluator.
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
undetermined_evals.insert(*r_it);
}
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
*continue_flag = true;
return latest_entry;
}
return latest_entry;
}
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
if (out_specs.size() == 0) {
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
}
if (out_specs.size() == 1) {
MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
}
auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) { const ConfigPtrList &args_conf_list) {
@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
return conf->GetEvaluatedValue()->abstract(); return conf->GetEvaluatedValue()->abstract();
}); });
for (auto eval : evaluators) { for (auto eval : evaluators) {
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); SetUndeterminedFlag(eval);
if (fg_eval) {
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
}
}
auto current_inf = std::make_pair(eval, args_spec_list); auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
multi_poss_.clear(); multi_poss_.clear();
} }
} else if (it != eval_trace_.rbegin()) { } else if (it != eval_trace_.rbegin()) {
// Find latest entry function to handle nested recursion. bool continue_flag = false;
EvaluatorPtr latest_entry = eval; auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
auto latest_entry_iter = eval_trace_.rbegin(); if (continue_flag) {
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
if (it_temp != evaluators.end()) {
latest_entry = *it_temp;
latest_entry_iter = r_it;
break;
}
latest_entry_iter = ++r_it;
}
if (latest_entry != eval) {
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
continue;
}
bool has_undetermined = false;
// Check whether sub loop has untraced undetermined evaluator.
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
undetermined_evals.insert(*r_it);
}
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
continue; continue;
} }
@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
} }
} }
} }
if (out_specs.size() == 0) {
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
}
if (out_specs.size() == 1) { return ProcessEvalResults(out_specs);
MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
}
auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
} }
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {

View File

@ -172,6 +172,8 @@ struct AnalysisResult {
AnalysisContextPtr context; AnalysisContextPtr context;
}; };
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public: public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
private: private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
bool *continue_flag);
EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs);
const PrimEvaluatorMap &prim_constructors_; const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_; FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;