!43 dump graph with type info when analysis failed
Merge pull request !43 from fary86/dump-typed-graph-when-analyze-fail
This commit is contained in:
commit
52166a85cf
|
@ -34,6 +34,7 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// max number of elements in sequence
|
// max number of elements in sequence
|
||||||
|
@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) {
|
||||||
|
|
||||||
// ============================================= MindSpore IR Exporter =============================================
|
// ============================================= MindSpore IR Exporter =============================================
|
||||||
|
|
||||||
std::string GetNodeType(const AnfNodePtr& nd) {
|
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
|
||||||
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
|
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
|
||||||
TypePtr type = dyn_cast<Type>(nd->Type());
|
TypePtr type = dyn_cast<Type>(nd->Type());
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
|
||||||
FuncGraphPtr fg = func_graph;
|
FuncGraphPtr fg = func_graph;
|
||||||
while (fg != nullptr) {
|
while (fg != nullptr) {
|
||||||
if (exported.find(fg) == exported.end()) {
|
if (exported.find(fg) == exported.end()) {
|
||||||
if (!export_used_) {
|
if (!check_integrity_) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
|
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
|
||||||
|
@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// output primitive attributes
|
// output primitive attributes
|
||||||
auto attrs = prim->attrs();
|
oss << prim->GetAttrsText();
|
||||||
if (attrs.size() > 0) {
|
|
||||||
oss << "[";
|
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||||
int i = 0;
|
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
|
||||||
for (auto& attr : attrs) {
|
auto& func = do_signature->function();
|
||||||
oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText();
|
if (func->isa<Primitive>()) {
|
||||||
i++;
|
auto sig_prim = dyn_cast<Primitive>(func);
|
||||||
|
oss << sig_prim->GetAttrsText();
|
||||||
}
|
}
|
||||||
oss << "]";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return oss.str();
|
return oss.str();
|
||||||
|
@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
|
||||||
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
|
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
|
||||||
if (export_used_) {
|
if (check_integrity_) {
|
||||||
MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
|
MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
|
||||||
}
|
}
|
||||||
oss << value->type_name() << "[" << value->DumpText() << "]";
|
oss << value->type_name() << "[" << value->DumpText() << "]";
|
||||||
|
@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
|
||||||
}
|
}
|
||||||
oss << "%" << iter->second;
|
oss << "%" << iter->second;
|
||||||
} else if (node->isa<Parameter>()) {
|
} else if (node->isa<Parameter>()) {
|
||||||
oss << "%para" << GetParamIndex(func_graph, node, export_used_);
|
oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
|
||||||
} else if (IsValueNode<FuncGraph>(node)) {
|
} else if (IsValueNode<FuncGraph>(node)) {
|
||||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||||
oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
|
oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
|
||||||
|
|
|
@ -64,17 +64,18 @@ struct ParamPtrHasher {
|
||||||
|
|
||||||
class AnfExporter {
|
class AnfExporter {
|
||||||
public:
|
public:
|
||||||
explicit AnfExporter(const std::string& id, bool export_used = true)
|
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
|
||||||
: param_index(-1), id_(id), export_used_(export_used) {
|
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
|
||||||
func_graph_set.clear();
|
func_graph_set.clear();
|
||||||
exported.clear();
|
exported.clear();
|
||||||
}
|
}
|
||||||
~AnfExporter() {}
|
virtual ~AnfExporter() {}
|
||||||
|
|
||||||
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
|
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
|
||||||
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
|
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
|
virtual std::string GetNodeType(const AnfNodePtr& nd);
|
||||||
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
|
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
|
||||||
int GetParamIndexFromExported(const AnfNodePtr& param);
|
int GetParamIndexFromExported(const AnfNodePtr& param);
|
||||||
std::string DumpObject(const py::object& obj, const std::string& category) const;
|
std::string DumpObject(const py::object& obj, const std::string& category) const;
|
||||||
|
@ -101,8 +102,10 @@ class AnfExporter {
|
||||||
OrderedSet<FuncGraphPtr> func_graph_set{};
|
OrderedSet<FuncGraphPtr> func_graph_set{};
|
||||||
OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
|
OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
|
||||||
std::string id_;
|
std::string id_;
|
||||||
bool export_used_ = true; // whether export function graphs used in current exporting function graph
|
bool export_used_ = true; // whether export function graphs used in current exporting function graph
|
||||||
|
bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true
|
||||||
TaggedNodeMap tagged_cnodes_;
|
TaggedNodeMap tagged_cnodes_;
|
||||||
|
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
|
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
|
||||||
|
@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
|
||||||
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
|
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
|
||||||
|
|
||||||
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
|
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
|
||||||
std::string GetNodeType(const AnfNodePtr& nd);
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
|
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -194,37 +195,116 @@ void TraceGraphInfer() {
|
||||||
MS_LOG(INFO) << "\n*************************************************************************************";
|
MS_LOG(INFO) << "\n*************************************************************************************";
|
||||||
}
|
}
|
||||||
|
|
||||||
void OutputAnalysisGraphInfo() {
|
class AnalyzedFuncGraphExporter : public AnfExporter {
|
||||||
MS_LOG(INFO) << "Output analysis graph begin";
|
public:
|
||||||
std::unordered_map<FuncGraphPtr, size_t> index_map;
|
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
|
||||||
std::vector<TaggedGraph> tagged_graphs;
|
~AnalyzedFuncGraphExporter() override = default;
|
||||||
|
|
||||||
|
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string GetNodeType(const AnfNodePtr& nd) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
|
||||||
|
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
|
||||||
auto& list = GetCNodeDebugStack();
|
auto& list = GetCNodeDebugStack();
|
||||||
for (size_t i = 0; i < list.size(); ++i) {
|
for (size_t i = 0; i < list.size(); ++i) {
|
||||||
auto& node_cfg = list[i];
|
auto node_cfg = list[i];
|
||||||
auto fg = node_cfg->context()->func_graph();
|
auto fg = node_cfg->context()->func_graph();
|
||||||
auto node = node_cfg->node();
|
auto node = node_cfg->node();
|
||||||
auto idx = tagged_graphs.size();
|
tagged_func_graphs[fg][node] = i;
|
||||||
std::pair<FuncGraphPtr, size_t> item(fg, idx);
|
}
|
||||||
if (index_map.insert(item).second) {
|
return tagged_func_graphs;
|
||||||
tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap()));
|
}
|
||||||
}
|
|
||||||
tagged_graphs[index_map[fg]].second[node] = i;
|
void OutputAnalyzedGraphWithType() {
|
||||||
|
AnalyzedFuncGraphExporter exporter;
|
||||||
|
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
|
||||||
|
if (node_cfg_ == nullptr) {
|
||||||
|
return AnfExporter::GetNodeType(node);
|
||||||
|
}
|
||||||
|
auto ctx = node_cfg_->context();
|
||||||
|
auto engine = node_cfg_->engine();
|
||||||
|
auto cfg = engine->MakeConfig(node, ctx);
|
||||||
|
auto abs = engine->cache().GetValue(cfg);
|
||||||
|
|
||||||
|
if (abs == nullptr) {
|
||||||
|
return "Undefined";
|
||||||
|
}
|
||||||
|
auto dtype = abs->BuildType();
|
||||||
|
auto shape = abs->BuildShape();
|
||||||
|
std::ostringstream oss;
|
||||||
|
if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) {
|
||||||
|
oss << dtype->DumpText() << shape->DumpText();
|
||||||
|
} else if (dtype != nullptr) {
|
||||||
|
oss << dtype->DumpText();
|
||||||
|
} else {
|
||||||
|
oss << "Undefined";
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
||||||
|
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
|
||||||
|
if (node_cfgs.empty()) {
|
||||||
|
MS_LOG(DEBUG) << "Node configs is empty";
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ExportIR("analyze_fail.dat", tagged_graphs);
|
std::ofstream ofs(filename);
|
||||||
MS_LOG(INFO) << "Output analysis graph *end*";
|
if (!ofs.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
param_index = 1;
|
||||||
|
auto tagged_func_graphs = CalcTaggedFuncGraphs();
|
||||||
|
|
||||||
|
// first output grapn on the analysis stack
|
||||||
|
for (const auto& node_cfg : node_cfgs) {
|
||||||
|
auto fg = node_cfg->context()->func_graph();
|
||||||
|
// the graph is already output, skip it
|
||||||
|
if (exported.find(fg) != exported.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// set node_cfg info for getting type
|
||||||
|
node_cfg_ = node_cfg;
|
||||||
|
tagged_cnodes_ = tagged_func_graphs[fg];
|
||||||
|
ExportOneFuncGraph(ofs, fg);
|
||||||
|
ofs << "\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
node_cfg_ = nullptr;
|
||||||
|
tagged_cnodes_.clear();
|
||||||
|
|
||||||
|
// print seperator between function graphs on analyzed graph call stack and others
|
||||||
|
ofs << "#===============================================================================\n\n\n";
|
||||||
|
|
||||||
|
// second output other graphs
|
||||||
|
while (!func_graph_set.empty()) {
|
||||||
|
FuncGraphPtr fg = *func_graph_set.begin();
|
||||||
|
ExportOneFuncGraph(ofs, fg);
|
||||||
|
ofs << "\n\n";
|
||||||
|
(void)func_graph_set.erase(fg);
|
||||||
|
}
|
||||||
|
ofs << "# num of total funcgraphs: " << exported.size();
|
||||||
|
|
||||||
|
ofs.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetInferStackInfo(std::ostringstream& oss) {
|
void GetInferStackInfo(std::ostringstream& oss) {
|
||||||
MS_LOG(INFO) << "Get graph analysis information begin";
|
MS_LOG(INFO) << "Get graph analysis information begin";
|
||||||
auto& stack = GetCNodeDebugStack();
|
auto stack = GetCNodeDebugStack();
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
MS_LOG(INFO) << "Length of analysis information stack is empty.";
|
MS_LOG(INFO) << "Length of analysis information stack is empty.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputAnalysisGraphInfo();
|
OutputAnalyzedGraphWithType();
|
||||||
oss << "\nThe function call stack:\n";
|
oss << "\nThe function call stack:\n";
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
|
|
|
@ -106,6 +106,27 @@ void Primitive::set_signatures(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Primitive::GetAttrsText() const {
|
||||||
|
if (attrs_.empty()) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[";
|
||||||
|
bool is_first = true;
|
||||||
|
for (auto& attr : attrs_) {
|
||||||
|
if (is_first) {
|
||||||
|
is_first = false;
|
||||||
|
} else {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
oss << attr.first << "=" << attr.second->DumpText();
|
||||||
|
}
|
||||||
|
oss << "]";
|
||||||
|
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
py::function PrimitivePy::GetBpropFunction() {
|
py::function PrimitivePy::GetBpropFunction() {
|
||||||
static const char* const get_bprop_func_name = "get_bprop";
|
static const char* const get_bprop_func_name = "get_bprop";
|
||||||
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
||||||
|
|
|
@ -102,6 +102,7 @@ class Primitive : public Named {
|
||||||
|
|
||||||
PrimType prim_type() const { return prim_type_; }
|
PrimType prim_type() const { return prim_type_; }
|
||||||
std::string instance_name() const { return instance_name_; }
|
std::string instance_name() const { return instance_name_; }
|
||||||
|
std::string GetAttrsText() const;
|
||||||
bool operator==(const Value& other) const override;
|
bool operator==(const Value& other) const override;
|
||||||
bool operator==(const Primitive& other) const;
|
bool operator==(const Primitive& other) const;
|
||||||
~Primitive() override = default;
|
~Primitive() override = default;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "operator/ops.h"
|
#include "operator/ops.h"
|
||||||
#include "pipeline/static_analysis/prim.h"
|
#include "pipeline/static_analysis/prim.h"
|
||||||
#include "pipeline/static_analysis/abstract_function.h"
|
#include "pipeline/static_analysis/abstract_function.h"
|
||||||
|
#include "debug/trace.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
using Shape = abstract::Shape;
|
using Shape = abstract::Shape;
|
||||||
|
@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
|
||||||
AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
|
AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
trace::ClearTraceStack();
|
||||||
engine_->Run(tupleSliceGraphPtr, args_spec_list);
|
engine_->Run(tupleSliceGraphPtr, args_spec_list);
|
||||||
FAIL() << "Excepted exception :Args type is wrong";
|
FAIL() << "Excepted exception :Args type is wrong";
|
||||||
} catch (std::runtime_error const &err) {
|
} catch (std::runtime_error const &err) {
|
||||||
|
|
Loading…
Reference in New Issue