optimize rdr mainly to reduce the redundant code

1. move common function to base recorder
2. add UpdataRdrEnable() for recorder manager, and update once
3. add delimiter virable for base recorder, set it as '.'
4. modify recorders to use the suffix parameter of GetFileRealPath()
5. improve the text in ms_context

Can not add RdrEnablePlatform because linux is supported only.
If os is linux, (!is_rdr_supported && rdr_enable) will be false always,
else, (is_rdr_supported && !rdr_enable) will be false always.
This commit is contained in:
luopengting 2021-02-23 17:18:24 +08:00
parent 692d158f5c
commit 7914363d25
14 changed files with 70 additions and 82 deletions

View File

@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
return;
}
has_rdr_setting_ = true;
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
if (rdr_enable.has_value()) {
ParseRdrEnable(**rdr_enable);

View File

@ -27,11 +27,14 @@ class EnvConfigParser {
public:
static EnvConfigParser &GetInstance() {
static EnvConfigParser instance;
instance.Parse();
return instance;
}
void Parse();
std::string config_path() const { return config_file_; }
bool has_rdr_setting() const { return has_rdr_setting_; }
bool rdr_enabled() const { return rdr_enabled_; }
std::string rdr_path() const { return rdr_path_; }
@ -42,7 +45,9 @@ class EnvConfigParser {
std::mutex lock_;
std::string config_file_{""};
bool already_parsed_{false};
bool rdr_enabled_{false};
bool has_rdr_setting_{false};
std::string rdr_path_{"./rdr/"};
std::string GetIfstreamString(const std::ifstream &ifstream);

View File

@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
if (filename_.empty()) {
filename_ = module_ + "_" + tag_;
filename_ = module_ + delimiter_ + tag_;
if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
filename_ += "_" + timestamp_;
filename_ += delimiter_ + timestamp_;
} else if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
std::string file_path = directory_ + filename_;
auto realpath = Common::GetRealPath(file_path);

View File

@ -65,6 +65,7 @@ class BaseRecorder {
void SetDirectory(const std::string &directory);
void SetFilename(const std::string &filename);
void SetModule(const std::string &module) { module_ = module; }
virtual void Export() {}
protected:
@ -73,6 +74,7 @@ class BaseRecorder {
std::string directory_;
std::string filename_;
std::string timestamp_; // year,month,day,hour,minute,second
std::string delimiter_{"."};
};
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
} // namespace mindspore

View File

@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
} // namespace
void GraphExecOrderRecorder::Export() {
auto realpath = GetFileRealPath();
auto realpath = GetFileRealPath(std::to_string(graph_id_));
if (!realpath.has_value()) {
return;
}
std::string real_file_path = realpath.value() + std::to_string(graph_id_);
std::string real_file_path = realpath.value() + ".txt";
DumpGraphExeOrder(real_file_path, exec_order_);
}
} // namespace mindspore

View File

@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
GraphExecOrderRecorder(const std::string &module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id)
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
void SetModule(const std::string &module) { module_ = module; }
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
virtual void Export();

View File

@ -65,6 +65,7 @@ void GraphRecorder::Export() {
}
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
auto tmp_realpath = GetFileRealPath(suffix);
if (!tmp_realpath.has_value()) {
return;
}

View File

@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
const std::string &file_type)
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
~GraphRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void SetDumpFlag(bool full_name) { full_name_ = full_name; }

View File

@ -21,7 +21,35 @@
#include "mindspore/core/ir/func_graph.h"
namespace mindspore {
void RecorderManager::UpdateRdrEnable() {
static bool updated = false;
if (updated) {
return;
}
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
rdr_enable_ = config_parser.rdr_enabled();
if (config_parser.has_rdr_setting()) {
#ifdef __linux__
if (!rdr_enable_) {
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
<< "' if you want to use RDR.";
}
#else
if (rdr_enable_) {
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
}
rdr_enable_ = false;
#endif
}
updated = true;
}
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
if (!rdr_enable_) {
return false;
}
if (recorder == nullptr) {
MS_LOG(ERROR) << "register recorder module with nullptr.";
return false;
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
}
void RecorderManager::TriggerAll() {
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance();
config_parser_ptr.Parse();
if (!config_parser_ptr.rdr_enabled()) {
MS_LOG(INFO) << "RDR is not enable.";
if (!rdr_enable_) {
return;
}

View File

@ -31,9 +31,12 @@ class RecorderManager {
public:
static RecorderManager &Instance() {
static RecorderManager manager;
manager.UpdateRdrEnable();
return manager;
}
void UpdateRdrEnable();
bool RdrEnable() const { return rdr_enable_; }
bool RecordObject(const BaseRecorderPtr &recorder);
void TriggerAll();
void ClearAll();
@ -42,6 +45,8 @@ class RecorderManager {
RecorderManager() {}
~RecorderManager() {}
bool rdr_enable_{false};
mutable std::mutex mtx_;
// module, BaserRecorderPtrList
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;

View File

@ -65,6 +65,9 @@ namespace RDR {
#ifdef ENABLE_D
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
TaskDebugInfoRecorderPtr task_debug_info_recorder =
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
}
#endif // ENABLE_D
#ifdef __linux__
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
graph_recorder->SetDumpFlag(full_name);
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphExecOrderRecorderPtr graph_exec_order_recorder =
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
string_recorder->SetFilename(filename);
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StreamExecOrderRecorderPtr stream_exec_order_recorder =
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
#else
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
static bool already_printed = false;
std::string submodule_name = std::string(GetSubModuleName(module));
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
return false;
}
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
void TriggerAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
void ClearAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
#endif // __linux__
} // namespace RDR
} // namespace mindspore

View File

@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
exec_order_.push_back(std::move(exec_node_ptr));
}
}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:

View File

@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
StringRecorder() : BaseRecorder() {}
StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
const std::string &filename)
: BaseRecorder(module, tag), data_(data), filename_(filename) {}
: BaseRecorder(module, tag), data_(data) {
SetFilename(filename);
}
~StringRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:
std::string data_;
std::string filename_;
};
using StringRecorderPtr = std::shared_ptr<StringRecorder>;
} // namespace mindspore

View File

@ -258,8 +258,8 @@ class _Context:
def set_env_config_path(self, env_config_path):
"""Check and set env_config_path."""
if not self._context_handle.enable_dump_ir():
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR "
"and recompile source to enable it.")
raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
"with '-D on' and recompile source.")
env_config_path = os.path.realpath(env_config_path)
if not os.path.isfile(env_config_path):
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)