Dump graph with type info when static analysis failed

This commit is contained in:
fary86 2020-03-28 17:53:22 +08:00
parent 930a1fb0a8
commit 816b60491d
6 changed files with 140 additions and 33 deletions

View File

@ -34,6 +34,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "operator/ops.h"
namespace mindspore { namespace mindspore {
// max number of elements in sequence // max number of elements in sequence
@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) {
// ============================================= MindSpore IR Exporter ============================================= // ============================================= MindSpore IR Exporter =============================================
std::string GetNodeType(const AnfNodePtr& nd) { std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
TypePtr type = dyn_cast<Type>(nd->Type()); TypePtr type = dyn_cast<Type>(nd->Type());
std::ostringstream oss; std::ostringstream oss;
@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
FuncGraphPtr fg = func_graph; FuncGraphPtr fg = func_graph;
while (fg != nullptr) { while (fg != nullptr) {
if (exported.find(fg) == exported.end()) { if (exported.find(fg) == exported.end()) {
if (!export_used_) { if (!check_integrity_) {
break; break;
} }
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
} }
// output primitive attributes // output primitive attributes
auto attrs = prim->attrs(); oss << prim->GetAttrsText();
if (attrs.size() > 0) {
oss << "["; if (prim->isa<prim::DoSignaturePrimitive>()) {
int i = 0; auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
for (auto& attr : attrs) { auto& func = do_signature->function();
oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText(); if (func->isa<Primitive>()) {
i++; auto sig_prim = dyn_cast<Primitive>(func);
oss << sig_prim->GetAttrsText();
} }
oss << "]";
} }
return oss.str(); return oss.str();
@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
std::ostringstream oss; std::ostringstream oss;
if (export_used_) { if (check_integrity_) {
MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText(); MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
} }
oss << value->type_name() << "[" << value->DumpText() << "]"; oss << value->type_name() << "[" << value->DumpText() << "]";
@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
} }
oss << "%" << iter->second; oss << "%" << iter->second;
} else if (node->isa<Parameter>()) { } else if (node->isa<Parameter>()) {
oss << "%para" << GetParamIndex(func_graph, node, export_used_); oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
} else if (IsValueNode<FuncGraph>(node)) { } else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node); FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();

View File

@ -64,17 +64,18 @@ struct ParamPtrHasher {
class AnfExporter { class AnfExporter {
public: public:
explicit AnfExporter(const std::string& id, bool export_used = true) explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used) { : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear(); func_graph_set.clear();
exported.clear(); exported.clear();
} }
~AnfExporter() {} virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
private: protected:
virtual std::string GetNodeType(const AnfNodePtr& nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param); int GetParamIndexFromExported(const AnfNodePtr& param);
std::string DumpObject(const py::object& obj, const std::string& category) const; std::string DumpObject(const py::object& obj, const std::string& category) const;
@ -102,7 +103,9 @@ class AnfExporter {
OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
std::string id_; std::string id_;
bool export_used_ = true; // whether export function graphs used in current exporting function graph bool export_used_ = true; // whether export function graphs used in current exporting function graph
bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true
TaggedNodeMap tagged_cnodes_; TaggedNodeMap tagged_cnodes_;
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
}; };
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetNodeType(const AnfNodePtr& nd);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

View File

@ -17,6 +17,7 @@
#include "debug/trace.h" #include "debug/trace.h"
#include <iostream> #include <iostream>
#include <fstream>
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -194,37 +195,116 @@ void TraceGraphInfer() {
MS_LOG(INFO) << "\n*************************************************************************************"; MS_LOG(INFO) << "\n*************************************************************************************";
} }
void OutputAnalysisGraphInfo() { class AnalyzedFuncGraphExporter : public AnfExporter {
MS_LOG(INFO) << "Output analysis graph begin"; public:
std::unordered_map<FuncGraphPtr, size_t> index_map; AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
std::vector<TaggedGraph> tagged_graphs; ~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
private:
std::string GetNodeType(const AnfNodePtr& nd) override;
};
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack(); auto& list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) { for (size_t i = 0; i < list.size(); ++i) {
auto& node_cfg = list[i]; auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph(); auto fg = node_cfg->context()->func_graph();
auto node = node_cfg->node(); auto node = node_cfg->node();
auto idx = tagged_graphs.size(); tagged_func_graphs[fg][node] = i;
std::pair<FuncGraphPtr, size_t> item(fg, idx);
if (index_map.insert(item).second) {
tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap()));
} }
tagged_graphs[index_map[fg]].second[node] = i; return tagged_func_graphs;
} }
ExportIR("analyze_fail.dat", tagged_graphs); void OutputAnalyzedGraphWithType() {
MS_LOG(INFO) << "Output analysis graph *end*"; AnalyzedFuncGraphExporter exporter;
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
}
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
auto ctx = node_cfg_->context();
auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg);
if (abs == nullptr) {
return "Undefined";
}
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
std::ostringstream oss;
if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) {
oss << dtype->DumpText() << shape->DumpText();
} else if (dtype != nullptr) {
oss << dtype->DumpText();
} else {
oss << "Undefined";
}
return oss.str();
}
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return;
}
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
return;
}
param_index = 1;
auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output grapn on the analysis stack
for (const auto& node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it
if (exported.find(fg) != exported.end()) {
continue;
}
// set node_cfg info for getting type
node_cfg_ = node_cfg;
tagged_cnodes_ = tagged_func_graphs[fg];
ExportOneFuncGraph(ofs, fg);
ofs << "\n\n";
}
node_cfg_ = nullptr;
tagged_cnodes_.clear();
// print seperator between function graphs on analyzed graph call stack and others
ofs << "#===============================================================================\n\n\n";
// second output other graphs
while (!func_graph_set.empty()) {
FuncGraphPtr fg = *func_graph_set.begin();
ExportOneFuncGraph(ofs, fg);
ofs << "\n\n";
(void)func_graph_set.erase(fg);
}
ofs << "# num of total funcgraphs: " << exported.size();
ofs.close();
} }
void GetInferStackInfo(std::ostringstream& oss) { void GetInferStackInfo(std::ostringstream& oss) {
MS_LOG(INFO) << "Get graph analysis information begin"; MS_LOG(INFO) << "Get graph analysis information begin";
auto& stack = GetCNodeDebugStack(); auto stack = GetCNodeDebugStack();
if (stack.empty()) { if (stack.empty()) {
MS_LOG(INFO) << "Length of analysis information stack is empty."; MS_LOG(INFO) << "Length of analysis information stack is empty.";
return; return;
} }
OutputAnalysisGraphInfo(); OutputAnalyzedGraphWithType();
oss << "\nThe function call stack:\n"; oss << "\nThe function call stack:\n";
int index = 0; int index = 0;

View File

@ -106,6 +106,27 @@ void Primitive::set_signatures(
} }
} }
std::string Primitive::GetAttrsText() const {
if (attrs_.empty()) {
return "";
}
std::ostringstream oss;
oss << "[";
bool is_first = true;
for (auto& attr : attrs_) {
if (is_first) {
is_first = false;
} else {
oss << ", ";
}
oss << attr.first << "=" << attr.second->DumpText();
}
oss << "]";
return oss.str();
}
py::function PrimitivePy::GetBpropFunction() { py::function PrimitivePy::GetBpropFunction() {
static const char* const get_bprop_func_name = "get_bprop"; static const char* const get_bprop_func_name = "get_bprop";
if (py::hasattr(python_obj_, get_bprop_func_name)) { if (py::hasattr(python_obj_, get_bprop_func_name)) {

View File

@ -102,6 +102,7 @@ class Primitive : public Named {
PrimType prim_type() const { return prim_type_; } PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; } std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const;
bool operator==(const Value& other) const override; bool operator==(const Value& other) const override;
bool operator==(const Primitive& other) const; bool operator==(const Primitive& other) const;
~Primitive() override = default; ~Primitive() override = default;

View File

@ -22,6 +22,7 @@
#include "operator/ops.h" #include "operator/ops.h"
#include "pipeline/static_analysis/prim.h" #include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/abstract_function.h" #include "pipeline/static_analysis/abstract_function.h"
#include "debug/trace.h"
namespace mindspore { namespace mindspore {
using Shape = abstract::Shape; using Shape = abstract::Shape;
@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
AbstractBasePtrList args_spec_list = {tuple_tensor, start_index}; AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
try { try {
trace::ClearTraceStack();
engine_->Run(tupleSliceGraphPtr, args_spec_list); engine_->Run(tupleSliceGraphPtr, args_spec_list);
FAIL() << "Excepted exception :Args type is wrong"; FAIL() << "Excepted exception :Args type is wrong";
} catch (std::runtime_error const &err) { } catch (std::runtime_error const &err) {