forked from mindspore-Ecosystem/mindspore
optimize rdr mainly to reduce the redundant code
1. move common function to base recorder 2. add UpdataRdrEnable() for recorder manager, and update once 3. add delimiter virable for base recorder, set it as '.' 4. modify recorders to use the suffix parameter of GetFileRealPath() 5. improve the text in ms_context Can not add RdrEnablePlatform because linux is supported only. If os is linux, (!is_rdr_supported && rdr_enable) will be false always, else, (is_rdr_supported && !rdr_enable) will be false always.
This commit is contained in:
parent
692d158f5c
commit
7914363d25
|
@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
|
|||
return;
|
||||
}
|
||||
|
||||
has_rdr_setting_ = true;
|
||||
|
||||
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
|
||||
if (rdr_enable.has_value()) {
|
||||
ParseRdrEnable(**rdr_enable);
|
||||
|
|
|
@ -27,11 +27,14 @@ class EnvConfigParser {
|
|||
public:
|
||||
static EnvConfigParser &GetInstance() {
|
||||
static EnvConfigParser instance;
|
||||
instance.Parse();
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Parse();
|
||||
std::string config_path() const { return config_file_; }
|
||||
|
||||
bool has_rdr_setting() const { return has_rdr_setting_; }
|
||||
bool rdr_enabled() const { return rdr_enabled_; }
|
||||
std::string rdr_path() const { return rdr_path_; }
|
||||
|
||||
|
@ -42,7 +45,9 @@ class EnvConfigParser {
|
|||
std::mutex lock_;
|
||||
std::string config_file_{""};
|
||||
bool already_parsed_{false};
|
||||
|
||||
bool rdr_enabled_{false};
|
||||
bool has_rdr_setting_{false};
|
||||
std::string rdr_path_{"./rdr/"};
|
||||
|
||||
std::string GetIfstreamString(const std::ifstream &ifstream);
|
||||
|
|
|
@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
|
|||
|
||||
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
|
||||
if (filename_.empty()) {
|
||||
filename_ = module_ + "_" + tag_;
|
||||
filename_ = module_ + delimiter_ + tag_;
|
||||
if (!suffix.empty()) {
|
||||
filename_ += "_" + suffix;
|
||||
filename_ += delimiter_ + suffix;
|
||||
}
|
||||
filename_ += "_" + timestamp_;
|
||||
filename_ += delimiter_ + timestamp_;
|
||||
} else if (!suffix.empty()) {
|
||||
filename_ += "_" + suffix;
|
||||
filename_ += delimiter_ + suffix;
|
||||
}
|
||||
std::string file_path = directory_ + filename_;
|
||||
auto realpath = Common::GetRealPath(file_path);
|
||||
|
|
|
@ -65,6 +65,7 @@ class BaseRecorder {
|
|||
|
||||
void SetDirectory(const std::string &directory);
|
||||
void SetFilename(const std::string &filename);
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export() {}
|
||||
|
||||
protected:
|
||||
|
@ -73,6 +74,7 @@ class BaseRecorder {
|
|||
std::string directory_;
|
||||
std::string filename_;
|
||||
std::string timestamp_; // year,month,day,hour,minute,second
|
||||
std::string delimiter_{"."};
|
||||
};
|
||||
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
|
|||
} // namespace
|
||||
|
||||
void GraphExecOrderRecorder::Export() {
|
||||
auto realpath = GetFileRealPath();
|
||||
auto realpath = GetFileRealPath(std::to_string(graph_id_));
|
||||
if (!realpath.has_value()) {
|
||||
return;
|
||||
}
|
||||
std::string real_file_path = realpath.value() + std::to_string(graph_id_);
|
||||
std::string real_file_path = realpath.value() + ".txt";
|
||||
DumpGraphExeOrder(real_file_path, exec_order_);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
|
|||
GraphExecOrderRecorder(const std::string &module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id)
|
||||
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
|
||||
virtual void Export();
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ void GraphRecorder::Export() {
|
|||
}
|
||||
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
|
||||
auto tmp_realpath = GetFileRealPath(suffix);
|
||||
|
||||
if (!tmp_realpath.has_value()) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
|
|||
const std::string &file_type)
|
||||
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
|
||||
~GraphRecorder() {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
|
||||
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
|
||||
void SetDumpFlag(bool full_name) { full_name_ = full_name; }
|
||||
|
|
|
@ -21,7 +21,35 @@
|
|||
#include "mindspore/core/ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
void RecorderManager::UpdateRdrEnable() {
|
||||
static bool updated = false;
|
||||
if (updated) {
|
||||
return;
|
||||
}
|
||||
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
|
||||
rdr_enable_ = config_parser.rdr_enabled();
|
||||
if (config_parser.has_rdr_setting()) {
|
||||
#ifdef __linux__
|
||||
if (!rdr_enable_) {
|
||||
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
|
||||
<< "' if you want to use RDR.";
|
||||
}
|
||||
#else
|
||||
if (rdr_enable_) {
|
||||
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
|
||||
}
|
||||
rdr_enable_ = false;
|
||||
#endif
|
||||
}
|
||||
updated = true;
|
||||
}
|
||||
|
||||
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
|
||||
if (!rdr_enable_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (recorder == nullptr) {
|
||||
MS_LOG(ERROR) << "register recorder module with nullptr.";
|
||||
return false;
|
||||
|
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
|
|||
}
|
||||
|
||||
void RecorderManager::TriggerAll() {
|
||||
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance();
|
||||
config_parser_ptr.Parse();
|
||||
if (!config_parser_ptr.rdr_enabled()) {
|
||||
MS_LOG(INFO) << "RDR is not enable.";
|
||||
if (!rdr_enable_) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,9 +31,12 @@ class RecorderManager {
|
|||
public:
|
||||
static RecorderManager &Instance() {
|
||||
static RecorderManager manager;
|
||||
manager.UpdateRdrEnable();
|
||||
return manager;
|
||||
}
|
||||
|
||||
void UpdateRdrEnable();
|
||||
bool RdrEnable() const { return rdr_enable_; }
|
||||
bool RecordObject(const BaseRecorderPtr &recorder);
|
||||
void TriggerAll();
|
||||
void ClearAll();
|
||||
|
@ -42,6 +45,8 @@ class RecorderManager {
|
|||
RecorderManager() {}
|
||||
~RecorderManager() {}
|
||||
|
||||
bool rdr_enable_{false};
|
||||
|
||||
mutable std::mutex mtx_;
|
||||
// module, BaserRecorderPtrList
|
||||
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;
|
||||
|
|
|
@ -65,6 +65,9 @@ namespace RDR {
|
|||
#ifdef ENABLE_D
|
||||
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
|
||||
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
TaskDebugInfoRecorderPtr task_debug_info_recorder =
|
||||
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
|
||||
|
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
|
|||
}
|
||||
#endif // ENABLE_D
|
||||
|
||||
#ifdef __linux__
|
||||
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
|
||||
const std::string &file_type) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
|
||||
graph_recorder->SetDumpFlag(full_name);
|
||||
|
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
|
|||
|
||||
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
GraphExecOrderRecorderPtr graph_exec_order_recorder =
|
||||
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
|
||||
|
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
|||
}
|
||||
|
||||
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
|
||||
string_recorder->SetFilename(filename);
|
||||
|
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
|
|||
|
||||
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
|
||||
const std::vector<CNodePtr> &exec_order) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
StreamExecOrderRecorderPtr stream_exec_order_recorder =
|
||||
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
|
||||
|
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
|
|||
|
||||
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
|
||||
|
||||
#else
|
||||
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
|
||||
const std::string &file_type) {
|
||||
static bool already_printed = false;
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
|
||||
const std::vector<CNodePtr> &exec_order) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
void TriggerAll() {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
}
|
||||
|
||||
void ClearAll() {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
}
|
||||
#endif // __linux__
|
||||
} // namespace RDR
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
|
|||
exec_order_.push_back(std::move(exec_node_ptr));
|
||||
}
|
||||
}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export();
|
||||
|
||||
private:
|
||||
|
|
|
@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
|
|||
StringRecorder() : BaseRecorder() {}
|
||||
StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
|
||||
const std::string &filename)
|
||||
: BaseRecorder(module, tag), data_(data), filename_(filename) {}
|
||||
: BaseRecorder(module, tag), data_(data) {
|
||||
SetFilename(filename);
|
||||
}
|
||||
~StringRecorder() {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export();
|
||||
|
||||
private:
|
||||
std::string data_;
|
||||
std::string filename_;
|
||||
};
|
||||
using StringRecorderPtr = std::shared_ptr<StringRecorder>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -258,8 +258,8 @@ class _Context:
|
|||
def set_env_config_path(self, env_config_path):
|
||||
"""Check and set env_config_path."""
|
||||
if not self._context_handle.enable_dump_ir():
|
||||
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR "
|
||||
"and recompile source to enable it.")
|
||||
raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
|
||||
"with '-D on' and recompile source.")
|
||||
env_config_path = os.path.realpath(env_config_path)
|
||||
if not os.path.isfile(env_config_path):
|
||||
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)
|
||||
|
|
Loading…
Reference in New Issue