!43 dump graph with type info when analysis failed

Merge pull request !43 from fary86/dump-typed-graph-when-analyze-fail
This commit is contained in:
mindspore-ci-bot 2020-03-31 22:23:02 +08:00 committed by Gitee
commit 52166a85cf
6 changed files with 140 additions and 33 deletions

View File

@ -34,6 +34,7 @@
#include "utils/utils.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
namespace mindspore {
// max number of elements in sequence
@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) {
// ============================================= MindSpore IR Exporter =============================================
std::string GetNodeType(const AnfNodePtr& nd) {
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
TypePtr type = dyn_cast<Type>(nd->Type());
std::ostringstream oss;
@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
FuncGraphPtr fg = func_graph;
while (fg != nullptr) {
if (exported.find(fg) == exported.end()) {
if (!export_used_) {
if (!check_integrity_) {
break;
}
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
}
// output primitive attributes
auto attrs = prim->attrs();
if (attrs.size() > 0) {
oss << "[";
int i = 0;
for (auto& attr : attrs) {
oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText();
i++;
oss << prim->GetAttrsText();
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
auto& func = do_signature->function();
if (func->isa<Primitive>()) {
auto sig_prim = dyn_cast<Primitive>(func);
oss << sig_prim->GetAttrsText();
}
oss << "]";
}
return oss.str();
@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
std::ostringstream oss;
if (export_used_) {
if (check_integrity_) {
MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
}
oss << value->type_name() << "[" << value->DumpText() << "]";
@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
}
oss << "%" << iter->second;
} else if (node->isa<Parameter>()) {
oss << "%para" << GetParamIndex(func_graph, node, export_used_);
oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
} else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();

View File

@ -64,17 +64,18 @@ struct ParamPtrHasher {
class AnfExporter {
public:
explicit AnfExporter(const std::string& id, bool export_used = true)
: param_index(-1), id_(id), export_used_(export_used) {
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear();
exported.clear();
}
~AnfExporter() {}
virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
private:
protected:
virtual std::string GetNodeType(const AnfNodePtr& nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param);
std::string DumpObject(const py::object& obj, const std::string& category) const;
@ -101,8 +102,10 @@ class AnfExporter {
OrderedSet<FuncGraphPtr> func_graph_set{};
OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
std::string id_;
bool export_used_ = true; // whether export function graphs used in current exporting function graph
bool export_used_ = true; // whether export function graphs used in current exporting function graph
bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true
TaggedNodeMap tagged_cnodes_;
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
};
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetNodeType(const AnfNodePtr& nd);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

View File

@ -17,6 +17,7 @@
#include "debug/trace.h"
#include <iostream>
#include <fstream>
#include <map>
#include <unordered_map>
#include <vector>
@ -194,37 +195,116 @@ void TraceGraphInfer() {
MS_LOG(INFO) << "\n*************************************************************************************";
}
void OutputAnalysisGraphInfo() {
MS_LOG(INFO) << "Output analysis graph begin";
std::unordered_map<FuncGraphPtr, size_t> index_map;
std::vector<TaggedGraph> tagged_graphs;
class AnalyzedFuncGraphExporter : public AnfExporter {
public:
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
private:
std::string GetNodeType(const AnfNodePtr& nd) override;
};
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) {
auto& node_cfg = list[i];
auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph();
auto node = node_cfg->node();
auto idx = tagged_graphs.size();
std::pair<FuncGraphPtr, size_t> item(fg, idx);
if (index_map.insert(item).second) {
tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap()));
}
tagged_graphs[index_map[fg]].second[node] = i;
tagged_func_graphs[fg][node] = i;
}
return tagged_func_graphs;
}
void OutputAnalyzedGraphWithType() {
AnalyzedFuncGraphExporter exporter;
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
}
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
auto ctx = node_cfg_->context();
auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg);
if (abs == nullptr) {
return "Undefined";
}
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
std::ostringstream oss;
if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) {
oss << dtype->DumpText() << shape->DumpText();
} else if (dtype != nullptr) {
oss << dtype->DumpText();
} else {
oss << "Undefined";
}
return oss.str();
}
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return;
}
ExportIR("analyze_fail.dat", tagged_graphs);
MS_LOG(INFO) << "Output analysis graph *end*";
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
return;
}
param_index = 1;
auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output grapn on the analysis stack
for (const auto& node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it
if (exported.find(fg) != exported.end()) {
continue;
}
// set node_cfg info for getting type
node_cfg_ = node_cfg;
tagged_cnodes_ = tagged_func_graphs[fg];
ExportOneFuncGraph(ofs, fg);
ofs << "\n\n";
}
node_cfg_ = nullptr;
tagged_cnodes_.clear();
// print seperator between function graphs on analyzed graph call stack and others
ofs << "#===============================================================================\n\n\n";
// second output other graphs
while (!func_graph_set.empty()) {
FuncGraphPtr fg = *func_graph_set.begin();
ExportOneFuncGraph(ofs, fg);
ofs << "\n\n";
(void)func_graph_set.erase(fg);
}
ofs << "# num of total funcgraphs: " << exported.size();
ofs.close();
}
void GetInferStackInfo(std::ostringstream& oss) {
MS_LOG(INFO) << "Get graph analysis information begin";
auto& stack = GetCNodeDebugStack();
auto stack = GetCNodeDebugStack();
if (stack.empty()) {
MS_LOG(INFO) << "Length of analysis information stack is empty.";
return;
}
OutputAnalysisGraphInfo();
OutputAnalyzedGraphWithType();
oss << "\nThe function call stack:\n";
int index = 0;

View File

@ -106,6 +106,27 @@ void Primitive::set_signatures(
}
}
std::string Primitive::GetAttrsText() const {
if (attrs_.empty()) {
return "";
}
std::ostringstream oss;
oss << "[";
bool is_first = true;
for (auto& attr : attrs_) {
if (is_first) {
is_first = false;
} else {
oss << ", ";
}
oss << attr.first << "=" << attr.second->DumpText();
}
oss << "]";
return oss.str();
}
py::function PrimitivePy::GetBpropFunction() {
static const char* const get_bprop_func_name = "get_bprop";
if (py::hasattr(python_obj_, get_bprop_func_name)) {

View File

@ -102,6 +102,7 @@ class Primitive : public Named {
PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const;
bool operator==(const Value& other) const override;
bool operator==(const Primitive& other) const;
~Primitive() override = default;

View File

@ -22,6 +22,7 @@
#include "operator/ops.h"
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "debug/trace.h"
namespace mindspore {
using Shape = abstract::Shape;
@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
try {
trace::ClearTraceStack();
engine_->Run(tupleSliceGraphPtr, args_spec_list);
FAIL() << "Excepted exception :Args type is wrong";
} catch (std::runtime_error const &err) {