!43 dump graph with type info when analysis failed
Merge pull request !43 from fary86/dump-typed-graph-when-analyze-fail
This commit is contained in:
commit
52166a85cf
|
@ -34,6 +34,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "debug/trace.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// max number of elements in sequence
|
||||
|
@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) {
|
|||
|
||||
// ============================================= MindSpore IR Exporter =============================================
|
||||
|
||||
std::string GetNodeType(const AnfNodePtr& nd) {
|
||||
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
|
||||
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
|
||||
TypePtr type = dyn_cast<Type>(nd->Type());
|
||||
std::ostringstream oss;
|
||||
|
@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
|
|||
FuncGraphPtr fg = func_graph;
|
||||
while (fg != nullptr) {
|
||||
if (exported.find(fg) == exported.end()) {
|
||||
if (!export_used_) {
|
||||
if (!check_integrity_) {
|
||||
break;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
|
||||
|
@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
|
|||
}
|
||||
|
||||
// output primitive attributes
|
||||
auto attrs = prim->attrs();
|
||||
if (attrs.size() > 0) {
|
||||
oss << "[";
|
||||
int i = 0;
|
||||
for (auto& attr : attrs) {
|
||||
oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText();
|
||||
i++;
|
||||
oss << prim->GetAttrsText();
|
||||
|
||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
|
||||
auto& func = do_signature->function();
|
||||
if (func->isa<Primitive>()) {
|
||||
auto sig_prim = dyn_cast<Primitive>(func);
|
||||
oss << sig_prim->GetAttrsText();
|
||||
}
|
||||
oss << "]";
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
|
@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
|
|||
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
|
||||
std::ostringstream oss;
|
||||
|
||||
if (export_used_) {
|
||||
if (check_integrity_) {
|
||||
MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
|
||||
}
|
||||
oss << value->type_name() << "[" << value->DumpText() << "]";
|
||||
|
@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
|
|||
}
|
||||
oss << "%" << iter->second;
|
||||
} else if (node->isa<Parameter>()) {
|
||||
oss << "%para" << GetParamIndex(func_graph, node, export_used_);
|
||||
oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
|
||||
} else if (IsValueNode<FuncGraph>(node)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||
oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
|
||||
|
|
|
@ -64,17 +64,18 @@ struct ParamPtrHasher {
|
|||
|
||||
class AnfExporter {
|
||||
public:
|
||||
explicit AnfExporter(const std::string& id, bool export_used = true)
|
||||
: param_index(-1), id_(id), export_used_(export_used) {
|
||||
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
|
||||
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
|
||||
func_graph_set.clear();
|
||||
exported.clear();
|
||||
}
|
||||
~AnfExporter() {}
|
||||
virtual ~AnfExporter() {}
|
||||
|
||||
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
|
||||
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
|
||||
|
||||
private:
|
||||
protected:
|
||||
virtual std::string GetNodeType(const AnfNodePtr& nd);
|
||||
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
|
||||
int GetParamIndexFromExported(const AnfNodePtr& param);
|
||||
std::string DumpObject(const py::object& obj, const std::string& category) const;
|
||||
|
@ -101,8 +102,10 @@ class AnfExporter {
|
|||
OrderedSet<FuncGraphPtr> func_graph_set{};
|
||||
OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
|
||||
std::string id_;
|
||||
bool export_used_ = true; // whether export function graphs used in current exporting function graph
|
||||
bool export_used_ = true; // whether export function graphs used in current exporting function graph
|
||||
bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true
|
||||
TaggedNodeMap tagged_cnodes_;
|
||||
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
|
||||
};
|
||||
|
||||
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
|
||||
|
@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
|
|||
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
|
||||
|
||||
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
|
||||
std::string GetNodeType(const AnfNodePtr& nd);
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "debug/trace.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
@ -194,37 +195,116 @@ void TraceGraphInfer() {
|
|||
MS_LOG(INFO) << "\n*************************************************************************************";
|
||||
}
|
||||
|
||||
void OutputAnalysisGraphInfo() {
|
||||
MS_LOG(INFO) << "Output analysis graph begin";
|
||||
std::unordered_map<FuncGraphPtr, size_t> index_map;
|
||||
std::vector<TaggedGraph> tagged_graphs;
|
||||
class AnalyzedFuncGraphExporter : public AnfExporter {
|
||||
public:
|
||||
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
|
||||
~AnalyzedFuncGraphExporter() override = default;
|
||||
|
||||
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
|
||||
|
||||
private:
|
||||
std::string GetNodeType(const AnfNodePtr& nd) override;
|
||||
};
|
||||
|
||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
|
||||
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
|
||||
auto& list = GetCNodeDebugStack();
|
||||
for (size_t i = 0; i < list.size(); ++i) {
|
||||
auto& node_cfg = list[i];
|
||||
auto node_cfg = list[i];
|
||||
auto fg = node_cfg->context()->func_graph();
|
||||
auto node = node_cfg->node();
|
||||
auto idx = tagged_graphs.size();
|
||||
std::pair<FuncGraphPtr, size_t> item(fg, idx);
|
||||
if (index_map.insert(item).second) {
|
||||
tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap()));
|
||||
}
|
||||
tagged_graphs[index_map[fg]].second[node] = i;
|
||||
tagged_func_graphs[fg][node] = i;
|
||||
}
|
||||
return tagged_func_graphs;
|
||||
}
|
||||
|
||||
void OutputAnalyzedGraphWithType() {
|
||||
AnalyzedFuncGraphExporter exporter;
|
||||
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
|
||||
}
|
||||
|
||||
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
|
||||
if (node_cfg_ == nullptr) {
|
||||
return AnfExporter::GetNodeType(node);
|
||||
}
|
||||
auto ctx = node_cfg_->context();
|
||||
auto engine = node_cfg_->engine();
|
||||
auto cfg = engine->MakeConfig(node, ctx);
|
||||
auto abs = engine->cache().GetValue(cfg);
|
||||
|
||||
if (abs == nullptr) {
|
||||
return "Undefined";
|
||||
}
|
||||
auto dtype = abs->BuildType();
|
||||
auto shape = abs->BuildShape();
|
||||
std::ostringstream oss;
|
||||
if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) {
|
||||
oss << dtype->DumpText() << shape->DumpText();
|
||||
} else if (dtype != nullptr) {
|
||||
oss << dtype->DumpText();
|
||||
} else {
|
||||
oss << "Undefined";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
||||
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
|
||||
if (node_cfgs.empty()) {
|
||||
MS_LOG(DEBUG) << "Node configs is empty";
|
||||
return;
|
||||
}
|
||||
|
||||
ExportIR("analyze_fail.dat", tagged_graphs);
|
||||
MS_LOG(INFO) << "Output analysis graph *end*";
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||
return;
|
||||
}
|
||||
|
||||
param_index = 1;
|
||||
auto tagged_func_graphs = CalcTaggedFuncGraphs();
|
||||
|
||||
// first output grapn on the analysis stack
|
||||
for (const auto& node_cfg : node_cfgs) {
|
||||
auto fg = node_cfg->context()->func_graph();
|
||||
// the graph is already output, skip it
|
||||
if (exported.find(fg) != exported.end()) {
|
||||
continue;
|
||||
}
|
||||
// set node_cfg info for getting type
|
||||
node_cfg_ = node_cfg;
|
||||
tagged_cnodes_ = tagged_func_graphs[fg];
|
||||
ExportOneFuncGraph(ofs, fg);
|
||||
ofs << "\n\n";
|
||||
}
|
||||
|
||||
node_cfg_ = nullptr;
|
||||
tagged_cnodes_.clear();
|
||||
|
||||
// print seperator between function graphs on analyzed graph call stack and others
|
||||
ofs << "#===============================================================================\n\n\n";
|
||||
|
||||
// second output other graphs
|
||||
while (!func_graph_set.empty()) {
|
||||
FuncGraphPtr fg = *func_graph_set.begin();
|
||||
ExportOneFuncGraph(ofs, fg);
|
||||
ofs << "\n\n";
|
||||
(void)func_graph_set.erase(fg);
|
||||
}
|
||||
ofs << "# num of total funcgraphs: " << exported.size();
|
||||
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
void GetInferStackInfo(std::ostringstream& oss) {
|
||||
MS_LOG(INFO) << "Get graph analysis information begin";
|
||||
auto& stack = GetCNodeDebugStack();
|
||||
auto stack = GetCNodeDebugStack();
|
||||
if (stack.empty()) {
|
||||
MS_LOG(INFO) << "Length of analysis information stack is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
OutputAnalysisGraphInfo();
|
||||
OutputAnalyzedGraphWithType();
|
||||
oss << "\nThe function call stack:\n";
|
||||
|
||||
int index = 0;
|
||||
|
|
|
@ -106,6 +106,27 @@ void Primitive::set_signatures(
|
|||
}
|
||||
}
|
||||
|
||||
std::string Primitive::GetAttrsText() const {
|
||||
if (attrs_.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
bool is_first = true;
|
||||
for (auto& attr : attrs_) {
|
||||
if (is_first) {
|
||||
is_first = false;
|
||||
} else {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << attr.first << "=" << attr.second->DumpText();
|
||||
}
|
||||
oss << "]";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetBpropFunction() {
|
||||
static const char* const get_bprop_func_name = "get_bprop";
|
||||
if (py::hasattr(python_obj_, get_bprop_func_name)) {
|
||||
|
|
|
@ -102,6 +102,7 @@ class Primitive : public Named {
|
|||
|
||||
PrimType prim_type() const { return prim_type_; }
|
||||
std::string instance_name() const { return instance_name_; }
|
||||
std::string GetAttrsText() const;
|
||||
bool operator==(const Value& other) const override;
|
||||
bool operator==(const Primitive& other) const;
|
||||
~Primitive() override = default;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "operator/ops.h"
|
||||
#include "pipeline/static_analysis/prim.h"
|
||||
#include "pipeline/static_analysis/abstract_function.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
using Shape = abstract::Shape;
|
||||
|
@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
|
|||
AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};
|
||||
|
||||
try {
|
||||
trace::ClearTraceStack();
|
||||
engine_->Run(tupleSliceGraphPtr, args_spec_list);
|
||||
FAIL() << "Excepted exception :Args type is wrong";
|
||||
} catch (std::runtime_error const &err) {
|
||||
|
|
Loading…
Reference in New Issue