!1822 fix codex split big functions
Merge pull request !1822 from fary86/codex_big_functions
This commit is contained in:
commit
20afadb4c0
|
@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
|
|||
|
||||
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
|
||||
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
|
||||
void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx,
|
||||
std::map<AnfNodePtr, int> *const apply_map);
|
||||
|
||||
private:
|
||||
std::string GetNodeType(const AnfNodePtr &nd) override;
|
||||
|
@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
|
|||
}
|
||||
auto abs = ret->abstract();
|
||||
if (abs == nullptr) {
|
||||
return nullptr;
|
||||
return "Undefined";
|
||||
}
|
||||
auto dtype = abs->BuildType();
|
||||
auto shape = abs->BuildShape();
|
||||
|
@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt
|
|||
return ctx;
|
||||
}
|
||||
|
||||
void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph,
|
||||
int *idx, std::map<AnfNodePtr, int> *const apply_map) {
|
||||
auto &inputs = cnode->inputs();
|
||||
std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map);
|
||||
// non-return node
|
||||
if (cnode != func_graph->get_return()) {
|
||||
int apply_idx = (*idx)++;
|
||||
(*apply_map)[cnode] = apply_idx;
|
||||
std::string type_info = GetNodeType(cnode);
|
||||
if (type_info == "Undefined") {
|
||||
ofs << " %" << apply_idx << " = " << op_text << "(";
|
||||
} else {
|
||||
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
|
||||
}
|
||||
} else {
|
||||
ofs << " " << op_text << "(";
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (i != 1) {
|
||||
ofs << ", ";
|
||||
}
|
||||
AnfNodePtr arg = inputs[i];
|
||||
ofs << GetAnfNodeText(func_graph, arg, *apply_map);
|
||||
}
|
||||
ofs << ")";
|
||||
|
||||
// process function graph call
|
||||
auto ctx = ProcessFuncGraphCall(cnode);
|
||||
|
||||
// output comment
|
||||
OutputStatementComment(ofs, cnode);
|
||||
if (ctx != nullptr) {
|
||||
ofs << " @ctx.addr=" << ctx.get();
|
||||
}
|
||||
ofs << "\n";
|
||||
|
||||
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
|
||||
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
|
||||
<< label_manage::Label(cnode->debug_info()) << "\n";
|
||||
} else {
|
||||
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
|
@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect
|
|||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
|
||||
// non-return node
|
||||
if (node != func_graph->get_return()) {
|
||||
int apply_idx = idx++;
|
||||
apply_map[node] = apply_idx;
|
||||
std::string type_info = GetNodeType(node);
|
||||
if (type_info == "Undefined") {
|
||||
ofs << " %" << apply_idx << " = " << op_text << "(";
|
||||
} else {
|
||||
ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "(";
|
||||
}
|
||||
} else {
|
||||
ofs << " " << op_text << "(";
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (i != 1) {
|
||||
ofs << ", ";
|
||||
}
|
||||
AnfNodePtr arg = inputs[i];
|
||||
ofs << GetAnfNodeText(func_graph, arg, apply_map);
|
||||
}
|
||||
ofs << ")";
|
||||
|
||||
// process function graph call
|
||||
auto ctx = ProcessFuncGraphCall(cnode);
|
||||
|
||||
// output comment
|
||||
OutputStatementComment(ofs, cnode);
|
||||
if (ctx != nullptr) {
|
||||
ofs << " @ctx.addr=" << ctx.get();
|
||||
}
|
||||
ofs << "\n";
|
||||
|
||||
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
|
||||
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#"
|
||||
<< label_manage::Label(cnode->debug_info()) << "\n";
|
||||
} else {
|
||||
ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n";
|
||||
}
|
||||
OutputCNode(ofs, cnode, func_graph, &idx, &apply_map);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -76,25 +76,16 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num
|
|||
return true;
|
||||
}
|
||||
|
||||
void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
||||
void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type,
|
||||
const size_t type_number) {
|
||||
*max_type_id = type_id;
|
||||
*max_type = type;
|
||||
*max_type_number = type_number;
|
||||
}
|
||||
|
||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
|
||||
const std::set<size_t> &write_indexs) {
|
||||
TypeId max_type_id = kTypeUnknown;
|
||||
TypeId max_type = kTypeUnknown;
|
||||
size_t max_type_number = 0;
|
||||
bool has_int8 = false;
|
||||
for (const auto &index : indexs) {
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
TypeId arg_type = kTypeUnknown;
|
||||
AbstractBasePtr arg_value = args_spec_list[index];
|
||||
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
|
||||
TypeId *arg_type = nullptr) {
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
auto is_write = (write_indexs.find(index) != write_indexs.end());
|
||||
if (is_write) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||
} else {
|
||||
|
@ -105,15 +96,36 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
arg_type_id = tensor_type->type_id();
|
||||
arg_type = kObjectTypeTensorType;
|
||||
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
||||
*arg_type_id = tensor_type->type_id();
|
||||
if (arg_type != nullptr) {
|
||||
*arg_type = kObjectTypeTensorType;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (arg_value->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
arg_type_id = scalar_type->type_id();
|
||||
arg_type = kObjectTypeNumber;
|
||||
} else {
|
||||
*arg_type_id = scalar_type->type_id();
|
||||
if (arg_type != nullptr) {
|
||||
*arg_type = kObjectTypeNumber;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices,
|
||||
const std::set<size_t> &write_indices) {
|
||||
TypeId max_type_id = kTypeUnknown;
|
||||
TypeId max_type = kTypeUnknown;
|
||||
size_t max_type_number = 0;
|
||||
bool has_int8 = false;
|
||||
for (const auto &index : indices) {
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
TypeId arg_type = kTypeUnknown;
|
||||
auto is_write = (write_indices.find(index) != write_indices.end());
|
||||
if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) {
|
||||
continue;
|
||||
}
|
||||
auto it = type_map.find(arg_type_id);
|
||||
|
@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|||
has_int8 = true;
|
||||
}
|
||||
if (max_type_id == kTypeUnknown) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (max_type == arg_type) {
|
||||
if (it->second > max_type_number) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
} else {
|
||||
if (arg_type == kObjectTypeTensorType) {
|
||||
if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
} else {
|
||||
if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) {
|
||||
setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||
const abstract::AbstractBasePtrList &args_spec_list,
|
||||
const std::set<size_t> &write_indexs) {
|
||||
const std::set<size_t> &write_indices) {
|
||||
// record index for signature.dtypes of the same type
|
||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
auto it = type_indexs.find(dtypes[i]);
|
||||
if (it == type_indexs.end()) {
|
||||
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
||||
auto it = type_indices.find(dtypes[i]);
|
||||
if (it == type_indices.end()) {
|
||||
(void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
||||
} else {
|
||||
it->second.push_back(i);
|
||||
}
|
||||
}
|
||||
std::map<SignatureEnumDType, TypeId> dst_type;
|
||||
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
||||
for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
|
||||
auto type = it->first;
|
||||
auto indexs = it->second;
|
||||
auto indices = it->second;
|
||||
// If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
|
||||
if (indexs.size() < 2) {
|
||||
if (indices.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
bool has_tensor = false;
|
||||
for (const auto &index : indexs) {
|
||||
for (const auto &index : indices) {
|
||||
AbstractBasePtr arg_value = args_spec_list[index];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
|
@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
|
|||
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
||||
continue;
|
||||
}
|
||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
|
||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices)));
|
||||
}
|
||||
return dst_type;
|
||||
}
|
||||
|
@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
|||
|
||||
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
||||
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
||||
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
|
||||
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature &sig) { return sig.dtype; });
|
||||
|
@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|||
return;
|
||||
}
|
||||
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
|
||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
|
||||
// Identify which arg requires auto cast
|
||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
auto rw_it = write_indexs.find(i);
|
||||
auto is_write = (rw_it != write_indexs.end());
|
||||
auto rw_it = write_indices.find(i);
|
||||
auto is_write = (rw_it != write_indices.end());
|
||||
|
||||
AbstractBasePtr arg_value = args_spec_list[i];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
if (is_write) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||
} else {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
}
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
arg_type_id = tensor_type->type_id();
|
||||
} else if (arg_value->isa<abstract::AbstractScalar>()) {
|
||||
auto scalar = arg_value->cast<abstract::AbstractScalarPtr>();
|
||||
auto scalar_type = scalar->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(scalar_type);
|
||||
arg_type_id = scalar_type->type_id();
|
||||
}
|
||||
AbstractBasePtr arg_value = args_spec_list[i];
|
||||
(void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
|
||||
auto it_map = type_map.find(arg_type_id);
|
||||
if (it_map == type_map.end()) {
|
||||
continue;
|
||||
|
@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
}
|
||||
}
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
std::set<size_t> write_indexs;
|
||||
std::set<size_t> write_indices;
|
||||
op_inputs.push_back(NewValueNode(function));
|
||||
// Assume, the write input of op is always the first input. We check if any write op,
|
||||
// and add cast op on other inputs to keep the same type with assigned parameter.
|
||||
|
@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
||||
write_indexs.insert(i);
|
||||
write_indices.insert(i);
|
||||
}
|
||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
||||
|
@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
}
|
||||
// process default
|
||||
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
||||
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
||||
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
|
||||
return func_graph->NewCNode(op_inputs);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
|
|||
return bprop_graph;
|
||||
}
|
||||
|
||||
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parse resolve function error.";
|
||||
return false;
|
||||
}
|
||||
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
|
||||
if (py::hasattr(obj, "bprop")) {
|
||||
FuncGraphPtr bprop_graph = nullptr;
|
||||
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
|
||||
if (enable_bprop_debug) {
|
||||
bprop_graph = ConvertToBpropCut(obj);
|
||||
} else {
|
||||
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
|
||||
}
|
||||
if (bprop_graph != nullptr) {
|
||||
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
|
||||
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
|
||||
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
}
|
||||
}
|
||||
*data = func_graph;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
|
||||
auto obj_type = data_converter::GetObjType(obj);
|
||||
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
|
||||
|
@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
|
|||
// Create the namespace for common class instance
|
||||
// When the obj is Cell, default parse the 'construct'
|
||||
if (data_converter::IsCellInstance(obj)) {
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parse resolve function error.";
|
||||
return false;
|
||||
return ConvertCellObjToFuncGraph(obj, data);
|
||||
}
|
||||
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
|
||||
if (py::hasattr(obj, "bprop")) {
|
||||
FuncGraphPtr bprop_graph = nullptr;
|
||||
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
|
||||
if (enable_bprop_debug) {
|
||||
bprop_graph = ConvertToBpropCut(obj);
|
||||
} else {
|
||||
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
|
||||
}
|
||||
if (bprop_graph != nullptr) {
|
||||
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
|
||||
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
|
||||
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
}
|
||||
}
|
||||
*data = func_graph;
|
||||
} else {
|
||||
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
|
||||
|
|
|
@ -608,7 +608,7 @@ void Pipeline::Run() {
|
|||
MS_LOG(INFO) << "End";
|
||||
}
|
||||
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) {
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
|
||||
std::size_t size = args.size();
|
||||
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
|
|
|
@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, bool need_run);
|
||||
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list);
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
||||
}
|
||||
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto undetermined_fgs = fg->recursive_graphs();
|
||||
if (undetermined_fgs) {
|
||||
auto fg_parent = fg->parent();
|
||||
MS_EXCEPTION_IF_NULL(fg_parent);
|
||||
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list,
|
||||
const EvalTraceRevIter &it, bool *continue_flag) {
|
||||
*continue_flag = false;
|
||||
// Find latest entry function to handle nested recursion.
|
||||
EvaluatorPtr latest_entry = eval;
|
||||
auto latest_entry_iter = eval_trace_.rbegin();
|
||||
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
|
||||
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
|
||||
if (it_temp != evaluators.end()) {
|
||||
latest_entry = *it_temp;
|
||||
latest_entry_iter = r_it;
|
||||
break;
|
||||
}
|
||||
latest_entry_iter = ++r_it;
|
||||
}
|
||||
if (latest_entry != eval) {
|
||||
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
|
||||
*continue_flag = true;
|
||||
return latest_entry;
|
||||
}
|
||||
|
||||
bool has_undetermined = false;
|
||||
// Check whether sub loop has untraced undetermined evaluator.
|
||||
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
|
||||
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
|
||||
undetermined_evals.insert(*r_it);
|
||||
}
|
||||
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
|
||||
|
||||
for (auto u_eval : undetermined_evals) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
|
||||
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
|
||||
has_undetermined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_undetermined == false) {
|
||||
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
|
||||
*continue_flag = true;
|
||||
return latest_entry;
|
||||
}
|
||||
|
||||
return latest_entry;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
|
||||
if (out_specs.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
|
||||
}
|
||||
|
||||
if (out_specs.size() == 1) {
|
||||
MS_EXCEPTION_IF_NULL(out_specs[0]);
|
||||
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
||||
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
auto joined_spec = AbstractJoin(out_specs);
|
||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
||||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
|
@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
return conf->GetEvaluatedValue()->abstract();
|
||||
});
|
||||
for (auto eval : evaluators) {
|
||||
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval) {
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto undetermined_fgs = fg->recursive_graphs();
|
||||
if (undetermined_fgs) {
|
||||
auto fg_parent = fg->parent();
|
||||
MS_EXCEPTION_IF_NULL(fg_parent);
|
||||
fg_parent->set_flags(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
|
||||
}
|
||||
}
|
||||
SetUndeterminedFlag(eval);
|
||||
|
||||
auto current_inf = std::make_pair(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
|
@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
multi_poss_.clear();
|
||||
}
|
||||
} else if (it != eval_trace_.rbegin()) {
|
||||
// Find latest entry function to handle nested recursion.
|
||||
EvaluatorPtr latest_entry = eval;
|
||||
auto latest_entry_iter = eval_trace_.rbegin();
|
||||
for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) {
|
||||
auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first);
|
||||
if (it_temp != evaluators.end()) {
|
||||
latest_entry = *it_temp;
|
||||
latest_entry_iter = r_it;
|
||||
break;
|
||||
}
|
||||
latest_entry_iter = ++r_it;
|
||||
}
|
||||
if (latest_entry != eval) {
|
||||
MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString();
|
||||
continue;
|
||||
}
|
||||
|
||||
bool has_undetermined = false;
|
||||
// Check whether sub loop has untraced undetermined evaluator.
|
||||
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
|
||||
for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
|
||||
undetermined_evals.insert(*r_it);
|
||||
}
|
||||
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
|
||||
for (auto u_eval : undetermined_evals) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
|
||||
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
|
||||
has_undetermined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_undetermined == false) {
|
||||
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
|
||||
bool continue_flag = false;
|
||||
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
|
||||
if (continue_flag) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
}
|
||||
}
|
||||
}
|
||||
if (out_specs.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
|
||||
}
|
||||
|
||||
if (out_specs.size() == 1) {
|
||||
MS_EXCEPTION_IF_NULL(out_specs[0]);
|
||||
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
||||
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
auto joined_spec = AbstractJoin(out_specs);
|
||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
||||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
return ProcessEvalResults(out_specs);
|
||||
}
|
||||
|
||||
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
|
||||
|
|
|
@ -172,6 +172,8 @@ struct AnalysisResult {
|
|||
AnalysisContextPtr context;
|
||||
};
|
||||
|
||||
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
|
||||
|
||||
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||
public:
|
||||
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
|
||||
|
@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
private:
|
||||
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
|
||||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
|
||||
bool *continue_flag);
|
||||
EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs);
|
||||
|
||||
const PrimEvaluatorMap &prim_constructors_;
|
||||
FuncGraphManagerPtr func_graph_manager_;
|
||||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
|
||||
|
|
Loading…
Reference in New Issue